import argparse
import numpy as np
import os
import torch
import json
from collections import Counter
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from datasets import load_dataset
import re

def extract_answer(solution_text: str):
    boxed_pattern = r'\\boxed\{([^}]*)\}'
    matches = re.findall(boxed_pattern, solution_text)
    if matches:
        return matches[-1].strip()
    return None

def apply_chat_template(toker, messages):
    input_prompt = toker.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    return toker(input_prompt, add_special_tokens=False).input_ids

def prepare_input_boxed(template, input_d):
    problem = input_d['problem']
    steps = input_d['steps']
    tagged_response = ''
    for sdx, step in enumerate(steps):
        tagged_response += f'<paragraph_{sdx}>\n{step}\n</paragraph_{sdx}>\n\n'
    tagged_response = tagged_response.strip()
    prompt = template.format(problem=problem, tagged_response=tagged_response)
    messages = [{'role': 'user', 'content': prompt}]
    return messages

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--configs', type=str, nargs='+', default=None,
                        choices=['gsm8k', 'math', 'olympiadbench', 'omnimath'])
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument("--output_dir", type=str, default='./outputs')
    parser.add_argument("--model_cache_dir", type=str, default='')
    parser.add_argument('--use_voting', action='store_true')
    parser.add_argument('--voting_n', type=int, default=8)
    args = parser.parse_args()

    args.model_name = os.path.basename(args.model_path)

    toker = AutoTokenizer.from_pretrained(args.model_path)
    TEMPLATE = open('./templates/critique_template.txt').read().strip()


    llm = LLM(
        model=args.model_path, tokenizer=args.model_path,
        gpu_memory_utilization=0.95,
        tensor_parallel_size=torch.cuda.device_count(),
        enable_prefix_caching=True, swap_space=16,
        max_num_seqs=20,
        download_dir=args.model_cache_dir
    )
    if not args.use_voting:
        sampling_params = SamplingParams(temperature=0.,
                                         max_tokens=32768 if 'QwQ' in args.model_path else 8192, seed=42)
    else:
        if 'Qwen2.5-Math' in args.model_path: # to ensure normal generation of Qwen2.5-Math-7B/72B-Instruct
            sampling_params = SamplingParams(temperature=0.7, top_p=0.8, top_k=20, n=args.voting_n,
                                            max_tokens=32768 if 'QwQ' in args.model_path else 8192, seed=42)
        else:
            sampling_params = SamplingParams(temperature=1, top_p=0.9, n=args.voting_n,
                                            max_tokens=32768 if 'QwQ' in args.model_path else 8192, seed=42)


    if args.configs is None:
        args.configs = ['gsm8k', 'math', 'olympiadbench', 'omnimath']

    for config in args.configs:
        if not args.use_voting:
            output_dir = os.path.join(args.output_dir, args.model_name)
        else:
            output_dir = os.path.join(args.output_dir, f'{args.model_name}_voting')
        os.makedirs(output_dir, exist_ok=True)

        input_data = load_dataset('Qwen/ProcessBench', split=config)
        prompt_token_ids = [apply_chat_template(toker, prepare_input_boxed(TEMPLATE, e))
                            for e in input_data]

        generations = llm.generate(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params)


        res_data = []
        for i in range(len(input_data)):
            d = input_data[i].copy()

            if not args.use_voting:
                generated_critique = generations[i].outputs[0].text
                pred = extract_answer(generated_critique)
                try:
                    pred = int(pred)
                except:
                    pred = None
            else:
                generated_critique = [ee.text for ee in generations[i].outputs]
                preds = [extract_answer(e) for e in generated_critique]
                preds = [e for e in preds if e is not None]
                if len(preds) == 0:
                    pred = None
                else:
                    pred = Counter(preds).most_common(1)[0][0]
                    try:
                        pred = int(pred)
                    except:
                        pred = None

            d['generated_critique'] = generated_critique
            d['prediction'] = pred
            d['match'] = (pred == d['label'])

            res_data.append(d)


        error_data = [e for e in res_data if e['label'] != -1]
        correct_data = [e for e in res_data if e['label'] == -1]

        with open(os.path.join(output_dir, f'{config}_error.jsonl'), 'w') as f:
            for e in error_data:
                f.write(json.dumps(e) + '\n')
        with open(os.path.join(output_dir, f'{config}_correct.jsonl'), 'w') as f:
            for e in correct_data:
                f.write(json.dumps(e) + '\n')
        
        acc1 = np.mean([e['match'] for e in error_data]) * 100
        acc2 = np.mean([e['match'] for e in correct_data]) * 100
        f1 = 2 * acc1 * acc2 / (acc1 + acc2)
        print(f'{config} error acc: {acc1:.1f}, correct acc: {acc2:.1f}, f1: {f1:.1f}')


if __name__ == '__main__':
    main()