import argparse
import numpy as np
import os
import torch
import json
from collections import Counter
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from datasets import load_dataset
import re

def extract_answer(solution_text: str):
    r"""Return the contents of the last complete ``\boxed{...}`` in the text.

    Braces are matched by depth, so nested groups such as
    ``\boxed{\frac{1}{2}}`` are returned whole instead of being cut at the
    first closing brace.

    Args:
        solution_text: Text to scan for ``\boxed{...}`` expressions.

    Returns:
        The whitespace-stripped contents of the last balanced ``\boxed{...}``,
        or ``None`` when the text contains no balanced occurrence.
    """
    marker = '\\boxed{'
    answer = None
    found_at = solution_text.find(marker)
    while found_at != -1:
        content_start = found_at + len(marker)
        depth = 1
        cursor = content_start
        # Walk forward until the brace opened by the marker is closed.
        while cursor < len(solution_text) and depth > 0:
            char = solution_text[cursor]
            if char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
            cursor += 1
        if depth == 0:
            # cursor sits one past the closing brace of this \boxed{...}.
            answer = solution_text[content_start:cursor - 1].strip()
            resume_at = cursor
        else:
            # This occurrence never closes; a later one inside it still may.
            resume_at = content_start
        found_at = solution_text.find(marker, resume_at)
    return answer

def apply_chat_template(toker, messages):
    """Render chat messages with the tokenizer's template and return token ids.

    The messages are first rendered to a prompt string (with the generation
    prompt appended), then that string is tokenized without adding special
    tokens.

    Args:
        toker: Tokenizer object exposing ``apply_chat_template`` and being
            callable on a string.
        messages: Chat messages passed through to ``toker.apply_chat_template``.

    Returns:
        The ``input_ids`` attribute of the tokenizer's output for the prompt.
    """
    rendered_prompt = toker.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    encoded = toker(rendered_prompt, add_special_tokens=False)
    return encoded.input_ids

def prepare_input_boxed(template, input_data, toker):
    """Build one tokenized critique prompt per solution step.

    Every step of every example becomes its own prompt. The step is wrapped in
    ``<paragraph_i>`` tags and passed as ``current_paragraph``; all earlier
    tagged steps of the same example are passed as ``tagged_response``. For the
    first step the "[Previous Paragraph(s)]" section is removed from the
    template, since there is nothing to show there.

    Args:
        template: Format string with ``{problem}``, ``{tagged_response}`` and
            ``{current_paragraph}`` placeholders.
        input_data: Iterable of examples with ``'id'``, ``'problem'``,
            ``'steps'`` and ``'label'`` keys.
        toker: Tokenizer handed to ``apply_chat_template``.

    Returns:
        A pair ``(results, checks)``: ``results`` holds the tokenized prompt of
        each step in order, and ``checks`` holds a matching
        ``{'id': ..., 'label': ...}`` dict per prompt.
    """
    # Template variant for the first paragraph, which has no predecessors.
    first_step_template = template.replace(
        "[Previous Paragraph(s)]\n\n{tagged_response}\n\n", ""
    )

    prompts = []
    checks = []
    for example in input_data:
        id_ = example['id']
        problem = example['problem']
        steps = example['steps']
        label = example['label']

        seen_paragraphs = []
        for position, step in enumerate(steps):
            tagged = f'<paragraph_{position}>\n{step}\n</paragraph_{position}>\n\n'
            if seen_paragraphs:
                prompt = template.format(
                    problem=problem,
                    tagged_response=''.join(seen_paragraphs),
                    current_paragraph=tagged,
                )
            else:
                prompt = first_step_template.format(
                    problem=problem, current_paragraph=tagged
                )
            prompts.append(prompt.strip())
            checks.append({'id': id_, 'label': label})
            seen_paragraphs.append(tagged)

    # Each prompt is sent as a single user turn.
    results = [
        apply_chat_template(toker, [{'role': 'user', 'content': prompt}])
        for prompt in prompts
    ]
    return results, checks

def _to_int_or_none(value):
    """Convert an extracted answer to ``int``; ``None`` if missing or not an integer."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


def _majority_prediction(critique_texts):
    """Return the most common boxed answer among the samples as an int, or ``None``."""
    votes = [ans for ans in map(extract_answer, critique_texts) if ans is not None]
    if not votes:
        return None
    return _to_int_or_none(Counter(votes).most_common(1)[0][0])


def _accuracy(rows):
    """Return the percentage of rows with a truthy ``'match'``; ``nan`` for no rows."""
    if not rows:
        return float('nan')
    return np.mean([row['match'] for row in rows]) * 100


def _harmonic_mean(first, second):
    """Return the harmonic mean of two scores, defined as 0.0 when they sum to 0."""
    total = first + second
    if total == 0:
        return 0.0
    return 2 * first * second / total


def main():
    """Run step-by-step critique evaluation on ProcessBench splits.

    For each requested split, one prompt is generated per solution step, the
    model's boxed verdict for every step is parsed to an int, and the index of
    the first step predicted as ``0`` becomes the example's prediction (``-1``
    when no step is predicted as ``0``). Results are written as
    ``<config>_error.jsonl`` / ``<config>_correct.jsonl`` under the output
    directory, and accuracy on both subsets plus their harmonic mean is
    printed.

    With ``--use_voting`` several completions are sampled per step and the
    most common boxed answer is used.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--configs', type=str, nargs='+', default=None,
                        choices=['gsm8k', 'math', 'olympiadbench', 'omnimath'])
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument("--output_dir", type=str, default='./process_critique_correctness_outputs')
    parser.add_argument("--model_cache_dir", type=str, default='')
    parser.add_argument('--use_voting', action='store_true')
    parser.add_argument('--voting_n', type=int, default=8)
    args = parser.parse_args()

    # normpath drops a trailing slash, which would otherwise make basename ''.
    args.model_name = os.path.basename(os.path.normpath(args.model_path))

    toker = AutoTokenizer.from_pretrained(args.model_path)
    with open('./templates/process_critique_correctness.txt') as template_file:
        TEMPLATE = template_file.read().strip()

    llm = LLM(
        model=args.model_path, tokenizer=args.model_path,
        gpu_memory_utilization=0.95,
        tensor_parallel_size=torch.cuda.device_count(),
        enable_prefix_caching=True, swap_space=16,
        max_num_seqs=20,
        download_dir=args.model_cache_dir
    )

    max_tokens = 32768 if 'QwQ' in args.model_path else 8192
    if not args.use_voting:
        sampling_params = SamplingParams(temperature=0., max_tokens=max_tokens, seed=42)
    elif 'Qwen2.5-Math' in args.model_path:
        # to ensure normal generation of Qwen2.5-Math-7B/72B-Instruct
        sampling_params = SamplingParams(temperature=0.7, top_p=0.8, top_k=20, n=args.voting_n,
                                         max_tokens=max_tokens, seed=42)
    else:
        sampling_params = SamplingParams(temperature=1, top_p=0.9, n=args.voting_n,
                                         max_tokens=max_tokens, seed=42)

    if args.configs is None:
        args.configs = ['gsm8k', 'math', 'olympiadbench', 'omnimath']

    # The output directory depends only on the model and the voting flag.
    run_name = f'{args.model_name}_voting' if args.use_voting else args.model_name
    output_dir = os.path.join(args.output_dir, run_name)
    os.makedirs(output_dir, exist_ok=True)

    for config in args.configs:
        input_data = load_dataset('Qwen/ProcessBench', split=config)

        prompt_token_ids, _ = prepare_input_boxed(TEMPLATE, input_data, toker)
        generations = llm.generate(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params)

        # One prediction (and the raw critique text(s)) per step prompt.
        preds = []
        critiques = []
        for generation in generations:
            if args.use_voting:
                generated_critique = [sample.text for sample in generation.outputs]
                pred = _majority_prediction(generated_critique)
            else:
                generated_critique = generation.outputs[0].text
                pred = _to_int_or_none(extract_answer(generated_critique))
            critiques.append(generated_critique)
            preds.append(pred)

        # Prompts were emitted step by step in dataset order, so each example
        # owns the next len(steps) entries of the flat lists.
        res_data = []
        offset = 0
        for example in input_data:
            d = dict(example)
            num_steps = len(d['steps'])
            pred_list = preds[offset:offset + num_steps]
            critiques_list = critiques[offset:offset + num_steps]
            offset += num_steps

            # Index of the first step predicted as 0, or -1 if there is none.
            # Steps whose prediction is None or nonzero are not counted.
            pred = next((idx for idx, val in enumerate(pred_list) if val == 0), -1)

            d['generated_critique'] = critiques_list
            d['prediction_list'] = pred_list
            d['prediction'] = pred
            d['match'] = (pred == d['label'])
            res_data.append(d)

        error_data = [e for e in res_data if e['label'] != -1]
        correct_data = [e for e in res_data if e['label'] == -1]

        with open(os.path.join(output_dir, f'{config}_error.jsonl'), 'w') as f:
            for e in error_data:
                f.write(json.dumps(e) + '\n')
        with open(os.path.join(output_dir, f'{config}_correct.jsonl'), 'w') as f:
            for e in correct_data:
                f.write(json.dumps(e) + '\n')

        acc1 = _accuracy(error_data)
        acc2 = _accuracy(correct_data)
        f1 = _harmonic_mean(acc1, acc2)
        print(f'{config} error acc: {acc1:.1f}, correct acc: {acc2:.1f}, f1: {f1:.1f}')


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()