import argparse
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import json
import torch
from datasets import DatasetDict
from src.data_prep import process_gsm, process_aime, process_mmlu, process_med, process_sudoku, process_gsm_symbolic
from src.pipelines import run_math_inference, run_feedback_inference, run_math_inference_with_feedback
from src.helper import evaluate_record
from vllm import LLM
from transformers import set_seed, AutoTokenizer

def parse_args():
    """Parse command-line options for the inference / feedback / refine pipeline.

    Options taking a value are declared with a type and default; boolean
    switches (type ``None`` in the table) are ``store_true`` flags that
    default to ``False``. Declaration order is preserved so ``--help`` lists
    options in the same order.

    Returns:
        argparse.Namespace with one attribute per option.
    """
    default_model = 'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B'
    # (option name, value type or None for a boolean switch, default)
    option_table = (
        ('initial_model_name', str, default_model),
        ('verifier_model_name', str, default_model),
        ('dataset_name', str, 'GSM'),
        ('med_subset', str, 'Surgery'),
        ('stop', int, 9999999),
        ('max_new_tokens', int, 2000),
        ('output_dir', str, 'output'),
        ('prompt_dir', str, 'prompts'),
        ('run_feedback', None, False),
        ('split', str, 'test'),
        ('skip_if_verified', None, False),
        ('temperature', float, 0.0),
        ('sudoku_num_prefilled', int, 10),
        ('exclude_final_assessment', None, False),
    )
    cli = argparse.ArgumentParser()
    for option_name, value_type, default_value in option_table:
        flag = f'--{option_name}'
        if value_type is None:
            cli.add_argument(flag, action='store_true', default=default_value)
        else:
            cli.add_argument(flag, type=value_type, default=default_value)
    return cli.parse_args()

def _load_dataset(args):
    """Build the raw dataset selected by ``args.dataset_name``.

    Args:
        args: Parsed CLI namespace. Reads ``dataset_name``, plus ``med_subset``
            (for MED) and ``sudoku_num_prefilled`` (for SUDOKU).

    Returns:
        Whatever the matching ``src.data_prep`` processor returns; the caller
        indexes it with ``'train'`` and ``'test'``.

    Raises:
        ValueError: if ``args.dataset_name`` is not a known dataset.
    """
    # Lambdas give every processor the same zero-argument call shape and defer
    # the processor lookup until the selected entry is actually invoked.
    loaders = {
        "GSM": lambda: process_gsm(),
        "GSM-S": lambda: process_gsm_symbolic(),
        "AIME": lambda: process_aime(),
        "MMLU": lambda: process_mmlu(),
        "MED": lambda: process_med(args.med_subset),
        "SUDOKU": lambda: process_sudoku(num_prefilled=args.sudoku_num_prefilled),
    }
    try:
        loader = loaders[args.dataset_name]
    except KeyError:
        raise ValueError(
            f"Unknown dataset_name {args.dataset_name!r}; "
            f"expected one of {sorted(loaders)}"
        ) from None
    return loader()


def _row_count(columns):
    """Return the number of rows in a column-oriented dict such as ``Dataset[:k]``.

    ``len()`` of such a dict counts columns, not rows, so measure the first
    column instead. An empty dict has 0 rows.
    """
    return len(next(iter(columns.values()), ()))


def _load_model(model_name, max_new_tokens, num_gpus):
    """Create a half-precision vLLM engine and the matching HF tokenizer.

    ``max_model_len`` is set to ``8 * max_new_tokens``, the same value used at
    every model-load site in this script.

    Returns:
        ``(llm, tokenizer)`` tuple.
    """
    llm = LLM(
        model=model_name,
        dtype='half',
        tensor_parallel_size=num_gpus,
        max_model_len=8 * max_new_tokens,
    )
    return llm, AutoTokenizer.from_pretrained(model_name)


def _free_gpu_memory():
    """Collect garbage and release cached CUDA blocks after a model was deleted.

    ``gc.collect()`` reclaims reference cycles that could otherwise keep the
    deleted engine alive. NOTE(review): vLLM may still retain some GPU memory
    after ``del``; confirm the swap fits on one device.
    """
    import gc
    gc.collect()
    torch.cuda.empty_cache()


if __name__ == '__main__':
    args = parse_args()

    # Base (non-"Instruct") models get a dataset-specific prompt template.
    # Instruct models get None so the name is always bound when it is passed
    # to run_math_inference (assumes the pipeline accepts None; TODO confirm).
    # NOTE(review): this path is fixed to 'prompts/' and ignores --prompt_dir.
    base_model_prompt = None
    if "Instruct" not in args.initial_model_name:
        base_model_prompt_path = f'prompts/base_model/inference_{args.dataset_name}.md'
        with open(base_model_prompt_path, 'r') as f:
            base_model_prompt = f.read()
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    with open(f'{args.prompt_dir}/feedback.md', 'r') as f:
        feedback_prompt = f.read()
    with open(f'{args.prompt_dir}/refine.md', 'r') as f:
        refine_prompt = f.read()
    # Attached to args, so the prompt texts are also written to args.json below.
    args.feedback_prompt = feedback_prompt
    args.refine_prompt = refine_prompt
    os.makedirs(args.output_dir, exist_ok=True)
    with open(f'{args.output_dir}/args.json', 'w') as f:
        json.dump(vars(args), f)
    set_seed(42)
    num_gpus = 1

    # load data
    dataset = _load_dataset(args)
    print(dataset["train"][0])
    # Hold out 10% of the original train split as validation (fixed seed).
    dataset_train_valid = dataset['train'].train_test_split(test_size=0.1, seed=42)
    # Slicing a Dataset yields a dict of column lists; that is the form the
    # pipelines receive via full_dataset[split]['question' / 'answer'].
    full_dataset = DatasetDict({
        'train': dataset_train_valid['train'][:args.stop],
        'val': dataset_train_valid['test'][:args.stop],
        'test': dataset['test'][:args.stop],
    })
    print(
        f'we have {_row_count(full_dataset["train"])} training data, '
        f'{_row_count(full_dataset["val"])} validation data and '
        f'{_row_count(full_dataset["test"])} testing data'
    )
    questions = full_dataset[args.split]['question']
    answers = full_dataset[args.split]['answer']

    # load initial model
    model, tokenizer = _load_model(args.initial_model_name, args.max_new_tokens, num_gpus)

    # first pass
    # batch_size=42424242 is presumably "larger than any split", i.e. one batch.
    _, initial_full_record = run_math_inference(
        model,
        tokenizer,
        data=(questions, answers),
        save_path=f'{args.output_dir}/initial_run',
        batch_size=42424242,
        max_new_tokens=args.max_new_tokens,
        COT=True,
        base_model_prompt=base_model_prompt
    )

    eval_df = evaluate_record(initial_full_record)
    eval_df.to_csv(f'{args.output_dir}/initial_run/evaluation.csv')
    if args.run_feedback:
        # Swap models only when a *different* verifier is requested; the same
        # flag decides whether to swap back, so an identical verifier is never
        # torn down and reloaded.
        use_separate_verifier = args.verifier_model_name not in ('None', args.initial_model_name)
        if use_separate_verifier:
            del model, tokenizer
            _free_gpu_memory()
            model, tokenizer = _load_model(args.verifier_model_name, args.max_new_tokens, num_gpus)
            print('finished initial run, deleted initial model for memory')
        # now run the feedback inference
        # NOTE(review): base_model_prompt receives the feedback prompt here
        # (and the refine prompt below); confirm this is intended.
        run_feedback_inference(
            model,
            tokenizer,
            full_record_path=f'{args.output_dir}/initial_run',
            save_path=f'{args.output_dir}/feedback_run',
            batch_size=42424242,
            max_new_tokens=args.max_new_tokens,
            feedback_prompt=feedback_prompt,
            base_model_prompt=feedback_prompt,
            temperature=args.temperature
        )
        if use_separate_verifier:
            del model, tokenizer
            _free_gpu_memory()
            model, tokenizer = _load_model(args.initial_model_name, args.max_new_tokens, num_gpus)
            print('finished feedback run, deleted feedback model for memory')

        _, final_full_record = run_math_inference_with_feedback(
            model,
            tokenizer,
            data=(questions, answers),
            full_record_path=f'{args.output_dir}/feedback_run',
            save_path=f'{args.output_dir}/final_run',
            batch_size=42424242,
            max_new_tokens=args.max_new_tokens,
            refine_prompt=refine_prompt,
            skip_if_verified=args.skip_if_verified,
            base_model_prompt=refine_prompt,
            exclude_final_assessment=args.exclude_final_assessment
        )
        eval_df = evaluate_record(final_full_record)
        eval_df.to_csv(f'{args.output_dir}/final_run/evaluation.csv')