import argparse
import numpy as np
import os
import torch
import json
from collections import Counter, defaultdict
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from datasets import load_dataset
import re
import math

# Total instance counts per config, keyed "<benchmark>_<generator model>".
# The same per-benchmark count is used for every generator model. main()
# divides these by 8 to get the score denominator (presumably 8 sampled
# responses per problem — NOTE(review): confirm against the dataset).
total_instances_dict = {
    f"{benchmark}_{generator}": count
    for generator in (
        "Eurus2_7B_sft",
        "Meta_Llama_3.1_70B_Instruct",
        "Qwen2.5_7B_Instruct",
    )
    for benchmark, count in (
        ("aime", 240),
        ("amc", 664),
        ("gpqa", 1584),
        ("leetcode", 1440),
        ("math", 4000),
        ("minerva_math", 2176),
        ("olympiadbench", 5400),
    )
}

def extract_answer(solution_text: str):
    """Return the stripped contents of the last ``\\boxed{...}`` in *solution_text*.

    Braces inside the box are balanced, so nested LaTeX such as
    ``\\boxed{\\frac{1}{2}}`` yields ``\\frac{1}{2}``. A ``[^}]*`` regex would
    stop at the first ``}`` and return ``\\frac{1``. Like ``re.findall``,
    matches do not overlap: after a complete box, scanning resumes after its
    closing brace. A box with no matching closing brace is ignored.

    Returns:
        The contents of the last complete box, stripped of surrounding
        whitespace, or ``None`` when there is no complete box.
    """
    marker = '\\boxed{'
    text_len = len(solution_text)
    answer = None
    start = solution_text.find(marker)
    while start != -1:
        content_start = start + len(marker)
        depth = 1
        pos = content_start
        closed = False
        while pos < text_len:
            ch = solution_text[pos]
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    closed = True
                    break
            pos += 1
        if closed:
            answer = solution_text[content_start:pos].strip()
            next_from = pos + 1  # non-overlapping, mirroring re.findall
        else:
            next_from = start + 1  # unterminated box: keep looking further on
        start = solution_text.find(marker, next_from)
    return answer

def apply_chat_template(toker, messages):
    """Render *messages* with the tokenizer's chat template and encode the result.

    The template is rendered as text with the generation prompt appended. That
    text is then tokenized without adding special tokens again, because the
    rendered template already contains them.

    Returns:
        The ``input_ids`` attribute of the tokenizer's encoding of the prompt.
    """
    rendered_prompt = toker.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    encoding = toker(rendered_prompt, add_special_tokens=False)
    return encoding.input_ids

def prepare_input_boxed(template, input_d):
    """Build the chat messages that ask the critic to judge a step-by-step solution.

    Each entry of ``input_d['steps']`` is wrapped in numbered, 1-based
    ``<paragraph_k>`` ... ``</paragraph_k>`` tags, and the paragraphs are
    separated by a blank line. The result and ``input_d['problem']`` are
    substituted into *template* with ``str.format``, using the
    ``{problem}`` and ``{tagged_response}`` placeholders.

    Returns:
        A single-message chat list: ``[{'role': 'user', 'content': prompt}]``.
    """
    paragraphs = [
        f'<paragraph_{number}>\n{step}\n</paragraph_{number}>'
        for number, step in enumerate(input_d['steps'], start=1)
    ]
    # A single join replaces the old quadratic `+=` loop. Every paragraph starts
    # with '<' and ends with '>', so there is no surrounding whitespace to strip.
    tagged_response = '\n\n'.join(paragraphs)
    prompt = template.format(problem=input_d['problem'], tagged_response=tagged_response)
    return [{'role': 'user', 'content': prompt}]

def get_logprobs(generation, toker, use_softmax=False, max_steps=10):
    """Score one critique by the model's confidence in its final boxed verdict.

    Finds the rightmost token sequence ``boxed`` ``{`` <"0"/"1" tokens> ``}`` in
    the generated token ids. It then reads the logprobs recorded at the first
    position inside the braces.

    Args:
        generation: One vLLM request output. Reads ``outputs[0].token_ids`` and
            ``outputs[0].logprobs``. The latter is indexed per position; each
            entry maps token id -> object with a ``.logprob`` attribute.
        toker: Unused; kept for call compatibility.
        use_softmax: If True, renormalise the "0"/"1" logprobs against each
            other and return P("1"). Otherwise return the raw "1" logprob.
        max_steps: Unused; kept for call compatibility.

    Returns:
        ``(prediction, score)``. ``prediction`` is 1 when "1" is favoured over
        "0", else 0. ``score`` is P("1") with ``use_softmax``, else the raw
        "1" logprob (``-inf`` if "1" was not recorded). When no verdict pattern
        is found, returns ``(0, 0.0)`` with ``use_softmax`` and ``(0, -inf)``
        without. With ``use_softmax``, ``(0, 0.0)`` is also returned when
        neither "0" nor "1" was recorded; the softmax is undefined there and
        would otherwise be NaN.
    """
    # NOTE(review): these ids are hard-coded for one specific tokenizer
    # vocabulary; they must change if the critique model's tokenizer changes.
    boxed_id = 79075  # Token ID for "boxed"
    left_brace_id = 90  # Token ID for "{"
    # Several different tokens can carry the closing "}".
    right_brace_ids = {92, 31716, 532, 7810}
    zero_id, one_id = 15, 16  # Token IDs for "0" and "1"
    verdict_ids = {zero_id, one_id}

    neg_inf = float('-inf')
    no_verdict = (0, 0.0) if use_softmax else (0, neg_inf)

    output = generation.outputs[0]
    tokens = output.token_ids
    n_tokens = len(tokens)

    # Scan for every boxed{<0/1 tokens>} pattern and remember only the rightmost.
    verdict_pos = None
    for i in range(n_tokens - 2):
        if tokens[i] != boxed_id or tokens[i + 1] != left_brace_id:
            continue
        j = i + 2
        while j < n_tokens and tokens[j] in verdict_ids:
            j += 1
        if j < n_tokens and tokens[j] in right_brace_ids:
            verdict_pos = i + 2  # first position inside the braces

    if verdict_pos is None:
        return no_verdict

    candidates = output.logprobs[verdict_pos]
    zero_entry = candidates.get(zero_id)
    one_entry = candidates.get(one_id)
    zero_logprob = zero_entry.logprob if zero_entry is not None else neg_inf
    one_logprob = one_entry.logprob if one_entry is not None else neg_inf

    if not use_softmax:
        return (1 if one_logprob > zero_logprob else 0), one_logprob

    max_logprob = max(zero_logprob, one_logprob)
    if max_logprob == neg_inf:
        # Neither token is among the recorded logprobs: -inf - -inf is NaN,
        # so report "no verdict" rather than propagating NaN downstream.
        return no_verdict
    # Softmax over the two candidates, shifted by the max for numerical stability.
    zero_exp = math.exp(zero_logprob - max_logprob)
    one_exp = math.exp(one_logprob - max_logprob)
    one_prob = one_exp / (zero_exp + one_exp)
    return (1 if one_prob > 0.5 else 0), one_prob

def _split_instance_id(instance_id):
    """Split an id of the form ``"<base_id>/<index>"`` into ``(base_id, int(index))``."""
    base_id, _, index = instance_id.rpartition('/')
    return base_id, int(index)


def _load_checkpoint(checkpoint_file):
    """Return saved critique records keyed by full instance id ({} if the file is missing)."""
    records = {}
    if not os.path.exists(checkpoint_file):
        return records
    with open(checkpoint_file) as f:
        for line in f:
            line = line.strip()
            if line:
                data = json.loads(line)
                records[data['id']] = data
    return records


def _write_checkpoint(checkpoint_file, records):
    """Overwrite *checkpoint_file* with one JSON object per record."""
    with open(checkpoint_file, 'w') as f:
        for record in records.values():
            f.write(json.dumps(record) + '\n')


def _group_instances(input_data):
    """Group dataset rows by base id, tagging each row with ``has_duplicate_steps``.

    Each group is ordered by the numeric suffix of its ids, so ``group[:N]``
    selects the first N sampled responses for that problem.
    """
    groups = defaultdict(list)
    for item in input_data:
        base_id, index = _split_instance_id(item['id'])
        steps = item['steps']
        item['has_duplicate_steps'] = len(steps) != len(set(steps))
        groups[base_id].append((index, item))
    return {
        base_id: [item for _, item in sorted(members, key=lambda member: member[0])]
        for base_id, members in groups.items()
    }


def _generate_critiques(llm, toker, template, sampling_params, instances, use_softmax):
    """Run the critic on *instances* in one batch and return new records keyed by id."""
    prompt_token_ids_list = []
    instance_map = []  # instance that each prompt corresponds to
    for instance in instances:
        prompt_ids = apply_chat_template(toker, prepare_input_boxed(template, instance))
        if prompt_ids:
            prompt_token_ids_list.append(prompt_ids)
            instance_map.append(instance)
    if not prompt_token_ids_list:
        return {}

    generations = llm.generate(prompt_token_ids=prompt_token_ids_list,
                               sampling_params=sampling_params)
    records = {}
    for instance, generation in zip(instance_map, generations):
        prediction, logprob = get_logprobs(generation, toker, use_softmax, len(instance['steps']))
        records[instance['id']] = {
            'id': instance['id'],
            'critique': generation.outputs[0].text,
            'logprob': logprob,
            'prediction': prediction,
            'has_duplicate_steps': instance['has_duplicate_steps'],
        }
    return records


def _score_config(groups, records, n):
    """Pick the highest-logprob response among the first *n* of each group.

    Returns ``(final_results, correct_count)``. ``final_results`` holds one row
    per considered response. ``correct_count`` counts the groups whose chosen
    response has a truthy ``final_answer_correct``. Responses without a
    record (e.g. skipped for duplicate steps) have ``None`` fields. If no
    response in a group has a logprob above ``-inf``, the first response is
    chosen.
    """
    final_results = []
    correct_count = 0
    for members in groups.values():
        candidates = members[:n]

        best_pos = 0
        best_logprob = float('-inf')
        for pos, instance in enumerate(candidates):
            record = records.get(instance['id'])
            logprob = None if record is None else record['logprob']
            if logprob is not None and logprob > best_logprob:
                best_logprob = logprob
                best_pos = pos

        for pos, instance in enumerate(candidates):
            record = records.get(instance['id']) or {}
            is_best = pos == best_pos
            final_results.append({
                'instance': instance,
                'generated_critique': record.get('critique'),
                'logprobs': record.get('logprob'),
                'predictions': record.get('prediction'),
                'is_best_response': is_best,
            })
            if is_best and instance.get('final_answer_correct', False):
                correct_count += 1
    return final_results, correct_count


def main():
    """Run best-of-N critique scoring for each requested config.

    For every config, the script:
    1. Loads the dataset split and groups responses by problem.
    2. Generates a critique for every response without duplicate steps that
       is not already in the checkpoint (``--resume``).
    3. Writes a result JSON for each N in {1, 2, 4, 8}, with the score in the
       file name.
    4. Deletes the checkpoint once the config is done.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--configs', type=str, nargs='+', default=None,
                        choices=list(total_instances_dict))
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument("--output_dir", type=str, default='./bon_outputs')
    parser.add_argument("--model_cache_dir", type=str, default='')
    parser.add_argument('--use_voting', action='store_true')
    parser.add_argument('--voting_n', type=int, default=8)
    parser.add_argument('--resume', action='store_true', help='Resume from checkpoint if exists')
    parser.add_argument('--use_softmax', action='store_true')
    args = parser.parse_args()

    # normpath drops a trailing '/', which would otherwise make basename ''.
    args.model_name = os.path.basename(os.path.normpath(args.model_path))

    toker = AutoTokenizer.from_pretrained(args.model_path)
    with open('./templates/outcome_cot_critique.txt') as f:
        template = f.read().strip()

    llm = LLM(
        model=args.model_path, tokenizer=args.model_path,
        gpu_memory_utilization=0.9,
        tensor_parallel_size=torch.cuda.device_count(),
        enable_prefix_caching=True, swap_space=16,
        max_num_seqs=20,
        download_dir=args.model_cache_dir
    )

    sampling_params = SamplingParams(temperature=0.6, logprobs=20, max_tokens=32768, seed=42)

    if args.configs is None:
        args.configs = list(total_instances_dict)

    output_dir = os.path.join(args.output_dir, args.model_name + "_cot_outcome_setting")
    os.makedirs(output_dir, exist_ok=True)

    for config in args.configs:
        checkpoint_file = os.path.join(output_dir, f'{config}_checkpoint.jsonl')

        # Records are per config and keyed by full instance id, so nothing
        # leaks between configs and every restored instance is kept.
        records = _load_checkpoint(checkpoint_file) if args.resume else {}

        input_data = load_dataset('prometheus-eval/filtered_bon_setting')[config]
        groups = _group_instances(input_data)

        # Responses with duplicate steps are never sent to the critic.
        pending = [
            instance
            for members in groups.values()
            for instance in members
            if not instance['has_duplicate_steps'] and instance['id'] not in records
        ]
        if pending:
            new_records = _generate_critiques(
                llm, toker, template, sampling_params, pending, args.use_softmax)
            if new_records:
                records.update(new_records)
                _write_checkpoint(checkpoint_file, records)

        # NOTE(review): the divisor 8 presumably is the number of sampled
        # responses per problem in the unfiltered set — confirm.
        total_instances = total_instances_dict[config] / 8
        for N in [1, 2, 4, 8]:
            final_results, correct_count = _score_config(groups, records, N)
            score = correct_count / total_instances if total_instances > 0 else 0
            print(f"{config} N={N} Score: {score:.4f} ({correct_count}/{total_instances})")

            # Save all results for this N value
            with open(os.path.join(output_dir, f'{config}_results_N={N}_score={str(score)}.json'), 'w') as f:
                json.dump(final_results, f, indent=4)

        if os.path.exists(checkpoint_file):
            os.remove(checkpoint_file)

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()