import argparse
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import json
import torch
from datasets import DatasetDict
from src.data_prep import process_gsm, process_aime, process_mmlu, process_sudoku
from src.helper import extract_math_answer
from src.model_loader import base_model_inference_aio
from vllm import LLM
from transformers import set_seed, AutoTokenizer

def parse_args(argv=None):
    """Parse command-line options for the generate -> verify -> refine evaluation run.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so existing command-line use is unchanged; passing an
            explicit list allows programmatic use and testing.

    Returns:
        argparse.Namespace holding the options defined below.
    """
    parser = argparse.ArgumentParser(
        description='Generate initial answers, verify them, refine them, and score the result.'
    )
    parser.add_argument('--initial_model_name', type=str, default='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B',
                        help='Model used for the initial and refinement generation passes.')
    parser.add_argument('--verifier_model_name', type=str, default='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B',
                        help='Model used for the verification (feedback) pass.')
    parser.add_argument('--dataset_name', type=str, default='GSM',
                        help='GSM, AIME, MMLU, or SUDOKU / SUDOKU_<num_prefilled> (defaults to 10 prefilled cells).')
    parser.add_argument('--stop', type=int, default=9999999,
                        help='Maximum number of examples kept from each split.')
    parser.add_argument('--max_new_tokens', type=int, default=2000,
                        help='Generation budget per call; the model context length is set to 8x this value.')
    parser.add_argument('--output_dir', type=str, default='output',
                        help='Directory where run arguments, the full record and the evaluation are written.')
    parser.add_argument('--prompt_dir', type=str, default='prompts',
                        help='Directory containing the prompt templates.')
    parser.add_argument('--split', type=str, default='test',
                        help='Dataset split name (not currently read by this script).')
    parser.add_argument('--temperature', type=float, default=0.0,
                        help='Sampling temperature used for all generation passes.')
    return parser.parse_args(argv)

if __name__ == '__main__':
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    # load data
    with open(f'{args.output_dir}/args.json', 'w') as f:
        json.dump(vars(args), f)
    set_seed(42)
    num_gpus = 1
    DATASET_PROCESSORS = {
        "GSM": process_gsm,
        "AIME": process_aime,
        "MMLU": process_mmlu,
    }

    # Get the appropriate processor function and execute it
    if "SUDOKU" in args.dataset_name:
        num_prefilled = int(args.dataset_name.split("_")[-1]) if "_" in args.dataset_name else 10
        dataset = process_sudoku(num_prefilled=num_prefilled)
        args.dataset_name = "SUDOKU"
    else:
        dataset = DATASET_PROCESSORS.get(args.dataset_name, lambda: None)()      
    with open(f'{args.prompt_dir}/feedback_{args.dataset_name}/verify.md', 'r') as f:
        feedback_prompt = f.read()
    with open(f'{args.prompt_dir}/feedback_{args.dataset_name}/binary_refine.md', 'r') as f:
        refine_prompt = f.read()
    with open(f'prompts/base_model/inference_{args.dataset_name}.md', 'r') as f:
        base_initial_prompt = f.read()
    args.feedback_prompt = feedback_prompt
    args.refine_prompt = refine_prompt
    print(dataset["train"][0])
    dataset_train_valid = dataset['train'].train_test_split(test_size=0.1, seed=42)
    full_dataset = DatasetDict({
        'train': dataset_train_valid['train'][:args.stop],
        'val': dataset_train_valid['test'][:args.stop],
        'test': dataset['test'][:args.stop]
    })
    print(f'we have {len(full_dataset["train"])} training data, {len(full_dataset["val"])} validation data and {len(full_dataset["test"])} testing data')

    # load initial model
    model = LLM(model=args.initial_model_name, dtype='half', tensor_parallel_size=num_gpus, max_model_len=8*args.max_new_tokens)
    tokenizer = AutoTokenizer.from_pretrained(args.initial_model_name)

    # first pass
    prompts = []
    for question in full_dataset['test']['question']:
        prompts.append(base_initial_prompt.format(question=question.strip()))
    print("DEBUG INTIIAL RUN")
    print(prompts[0])
    responses = base_model_inference_aio(model, prompts, temperature=args.temperature, max_new_tokens=args.max_new_tokens)
    initial_record = []
    for response, question, label in zip(responses, full_dataset['test']['question'], full_dataset['test']['answer']):
        initial_record.append({
            'question': question,
            'answer': extract_math_answer(response),
            'response': response,
            'label': extract_math_answer(label),
        })
    if args.verifier_model_name != args.initial_model_name:
        # finished intial run delete models
        del model
        del tokenizer
        torch.cuda.empty_cache()
        model = LLM(model=args.verifier_model_name, dtype='half', tensor_parallel_size=num_gpus, max_model_len=8*args.max_new_tokens)
        tokenizer = AutoTokenizer.from_pretrained(args.verifier_model_name)
    prompts = []
    print("DEBUG initial_record:", type(initial_record), len(initial_record))
    print(initial_record[0])
    for entry in initial_record:
        prompts.append(feedback_prompt.format(question=entry['question'].strip(), initial_response=entry['response'].strip()))
    print("DEBUG FEEDBACK PROMPT")
    print(prompts[0])
    responses = base_model_inference_aio(model, prompts, temperature=args.temperature, max_new_tokens=args.max_new_tokens, stop_tokens=["Problem:", "Solution:", "Analysis:", "Feedback Received:", "Previous Solution:", "Original Problem:", "Instruction:", "Revised Solution:"])
    for entry, response in zip(initial_record, responses):
        entry['feedback'] = response
    if args.verifier_model_name != args.initial_model_name:
        # finished intial run delete models
        del model
        del tokenizer
        torch.cuda.empty_cache()
        model = LLM(model=args.initial_model_name, dtype='half', tensor_parallel_size=num_gpus, max_model_len=8*args.max_new_tokens)
        tokenizer = AutoTokenizer.from_pretrained(args.initial_model_name)
    prompts = []
    for entry in initial_record:
        prompts.append(refine_prompt.format(question=entry['question'].strip(), initial_response=entry['response'].strip()))
    responses = base_model_inference_aio(model, prompts, temperature=args.temperature, max_new_tokens=args.max_new_tokens, stop_tokens=["Problem:", "Solution:", "Analysis:", "Feedback Received:", "Previous Solution:", "Original Problem:", "Instruction:", "Revised Solution:"])
    for entry, response in zip(initial_record, responses):
        entry['refined_response'] = response
    print("now evaluating")
    initial_correct = 0
    correct = 0
    oracle_correct = 0
    verified = 0
    for entry in initial_record:
        if entry['answer'] == entry['label'] or extract_math_answer(entry['refined_response']) == entry['label']:
            oracle_correct += 1
        if entry['answer'] == entry['label'] and (extract_math_answer(entry['feedback']) == 1 or "there is no error" in entry['feedback'].lower()):
            correct += 1
            verified += 1
        if ("there is no error" not in entry['feedback'].lower() or extract_math_answer(entry['feedback']) == 0) and extract_math_answer(entry['refined_response']) == entry['label'] and entry['answer'] != entry['label']:
            correct += 1
        if entry['answer'] != entry['label'] and extract_math_answer(entry['feedback']) == 0:
            verified += 1
        if entry['answer'] == entry['label']:
            initial_correct += 1
    print(f"initial accuracy: {initial_correct/len(initial_record)}")
    print(f"oracle accuracy: {oracle_correct/len(initial_record)}")
    print(f"final accuracy: {correct/len(initial_record)}")
    print(f"good verified rate: {verified/len(initial_record)}")
    with open(f'{args.output_dir}/full_record.json', 'w') as f:
        json.dump(initial_record, f, indent=4)
    with open(f'{args.output_dir}/evaluation.txt', 'w') as f:
        f.write(f"initial accuracy: {initial_correct/len(initial_record)}\n")
        f.write(f"oracle accuracy: {oracle_correct/len(initial_record)}\n")
        f.write(f"final accuracy: {correct/len(initial_record)}\n")
        f.write(f"good verified rate: {verified/len(initial_record)}\n")