import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.model_loader import base_model_inference_aio
from src.helper import extract_math_answer
from src.data_prep import process_aime, process_gsm, process_mmlu, process_sudoku
from src.pipelines import run_math_inference, verify_from_response
from transformers import AutoTokenizer
import torch
from vllm import LLM
from datasets import DatasetDict
import argparse
import json

def parse_args():
    """Parse the command-line options of the self-refine script.

    Every option is optional and falls back to the default listed in the
    table below.

    Returns:
        argparse.Namespace: Parsed options, exposed as attributes named after
        the flags (``model_name``, ``dataset_name``, ``output_dir``,
        ``max_new_tokens``, ``n_rounds``, ``stop``, ``temperature``,
        ``prompt_dir``).
    """
    # One row per option: (flag, value type, default value).
    option_table = (
        ("--model_name", str, "Qwen/Qwen2.5-1.5B-Instruct"),
        ("--dataset_name", str, "AIME"),
        ("--output_dir", str, "outputs"),
        ("--max_new_tokens", int, 1024),
        ("--n_rounds", int, 3),
        ("--stop", int, 1000),
        ("--temperature", float, 0.7),
        ("--prompt_dir", str, "prompts/self_refine"),
    )
    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in option_table:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli.parse_args()

def create_prompt_from_history(single_history, prompt):
    """Build the refinement prompt for one problem from its round history.

    Each round contributes, in this order and only for the keys it contains,
    a ``Problem:`` section (``'question'``), a ``Solution:`` section
    (``'response'``) and an ``Analysis:`` section (``'feedback'``). Other keys
    are ignored. The assembled text is stripped and terminated with an open
    ``Solution:`` cue.

    Args:
        single_history: Iterable of round dicts with string values.
        prompt: Prompt prefix placed before the rendered rounds.

    Returns:
        str: The complete refinement prompt.
    """
    # Dict key -> heading used in the prompt, in rendering order.
    sections = (
        ("question", "Problem"),
        ("response", "Solution"),
        ("feedback", "Analysis"),
    )
    pieces = [prompt, "\n"]
    for turn in single_history:
        for key, heading in sections:
            if key in turn:
                pieces.append(f"{heading}: {turn[key].strip()}\n\n")
    # Strip the whole text (prefix included) before adding the final cue.
    return "".join(pieces).strip() + "\n\nSolution:"

def create_feedback_prompt(question, response, prompt):
    """Build the prompt that asks for an analysis of one solution.

    Args:
        question: Problem statement; surrounding whitespace is stripped.
        response: Candidate solution; surrounding whitespace is stripped.
        prompt: Prompt prefix placed before the problem.

    Returns:
        str: ``prompt`` followed by the ``Problem:`` and ``Solution:``
        sections and an open ``Analysis:`` cue.
    """
    return "".join([
        prompt,
        "\n",
        f"Problem: {question.strip()}\n\n",
        f"Solution: {response.strip()}\n\n",
        "Analysis:",
    ])

def generate_feedback(model, history, max_new_tokens, temperature, prompt):
    """Generate feedback for the latest response of every history entry.

    One feedback prompt is built per entry from the question stored in the
    entry's first round and the response stored in its last round; all prompts
    are then sent to the model in a single batched call.

    Args:
        model: Model handle forwarded unchanged to ``base_model_inference_aio``.
        history: Sequence of per-problem round lists. ``entry[0]`` must contain
            ``'question'`` and ``entry[-1]`` must contain ``'response'``.
        max_new_tokens: Generation budget forwarded to the inference helper.
        temperature: Sampling temperature forwarded to the inference helper.
        prompt: Feedback prompt prefix placed before every problem.

    Returns:
        Whatever ``base_model_inference_aio`` returns for the prompt batch.
    """
    prompts = [
        create_feedback_prompt(entry[0]['question'].strip(), entry[-1]['response'].strip(), prompt)
        for entry in history
    ]
    print("========= for feedback =============")
    # Preview the first prompt only when one exists, so an empty history does
    # not die with IndexError on a debug print.
    if prompts:
        print(prompts[0])
    print("DEBUG")
    # Stop as soon as the model starts writing the next section heading.
    results = base_model_inference_aio(
        model,
        prompts,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        stop_tokens=["Problem:", "Solution:", "Analysis:"],
    )
    return results

def generate_refine_response(model, history, max_new_tokens, temperature, prompt):
    """Generate a refined solution for every history entry.

    One refinement prompt is built per entry from its full round history via
    ``create_prompt_from_history``; all prompts are then sent to the model in
    a single batched call.

    Args:
        model: Model handle forwarded unchanged to ``base_model_inference_aio``.
        history: Sequence of per-problem round lists.
        max_new_tokens: Generation budget forwarded to the inference helper.
        temperature: Sampling temperature forwarded to the inference helper.
        prompt: Refinement prompt prefix placed before every history.

    Returns:
        Whatever ``base_model_inference_aio`` returns for the prompt batch.
    """
    prompts = [create_prompt_from_history(each, prompt) for each in history]
    print("========= for refine =============")
    # Preview the first prompt only when one exists, so an empty history does
    # not die with IndexError on a debug print.
    if prompts:
        print(prompts[0])
    print("DEBUG")
    # Stop as soon as the model starts writing the next section heading.
    results = base_model_inference_aio(
        model,
        prompts,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        stop_tokens=["Problem:", "Solution:", "Analysis:"],
    )
    return results

if __name__ == "__main__":
    # Expose only GPU 0 to this process; set before the device count is queried.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    args = parse_args()

    # Loaders for the non-Sudoku datasets, keyed by --dataset_name.
    DATASET_PROCESSORS = {
        "GSM": process_gsm,
        "AIME": process_aime,
        "MMLU": process_mmlu
    }
    # Fail fast on an unknown dataset, before the (expensive) model load.
    # Previously an unknown name silently produced dataset=None and only
    # surfaced later as an unrelated error.
    is_sudoku = "SUDOKU" in args.dataset_name
    if not is_sudoku and args.dataset_name not in DATASET_PROCESSORS:
        raise ValueError(
            f"Unknown dataset_name {args.dataset_name!r}; expected one of "
            f"{sorted(DATASET_PROCESSORS)} or SUDOKU_<num_prefilled>"
        )

    os.makedirs(args.output_dir, exist_ok=True)
    num_gpus = torch.cuda.device_count()
    # Context window: at least 20000 tokens, or n_rounds * max_new_tokens if larger.
    model = LLM(model=args.model_name, dtype='half', tensor_parallel_size=num_gpus, max_model_len=max(args.n_rounds*args.max_new_tokens, 20000))
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # load dataset
    if is_sudoku:
        # Names look like "SUDOKU_<n>": the trailing integer is num_prefilled.
        num_prefilled = int(args.dataset_name.split("_")[-1])
        dataset = process_sudoku(num_prefilled=num_prefilled)
        # Prompt files are looked up under the plain "SUDOKU" name.
        args.dataset_name = "SUDOKU"
    else:
        dataset = DATASET_PROCESSORS[args.dataset_name]()

    # Prompt templates: initial generation, feedback, and refinement.
    with open(f'prompts/base_model/inference_{args.dataset_name}.md', 'r', encoding='utf-8') as f:
        initial_gen_prompt = f.read()
    with open(f'{args.prompt_dir}/{args.dataset_name}/feedback.md', 'r', encoding='utf-8') as f:
        feedback_prompt = f.read()
    with open(f'{args.prompt_dir}/{args.dataset_name}/refine.md', 'r', encoding='utf-8') as f:
        refine_prompt = f.read()

    # Carve a validation split out of train, then cap every split at --stop rows.
    # NOTE(review): if these are Hugging Face Dataset objects, [:stop] presumably
    # yields a dict of column lists rather than a Dataset -- confirm.
    dataset_train_valid = dataset['train'].train_test_split(test_size=0.1, seed=42)
    full_dataset = DatasetDict({
        'train': dataset_train_valid['train'][:args.stop],
        'val': dataset_train_valid['test'][:args.stop],
        'test': dataset['test'][:args.stop]
    })

    # Initial responses y_0 are generated on the *test* split.
    _, initial_records = run_math_inference(
        model,
        tokenizer,
        data=(full_dataset['test']['question'], full_dataset['test']['answer']),
        save_path=f'{args.output_dir}/temp',
        batch_size=500,
        max_new_tokens=args.max_new_tokens,
        COT=True,
        temperature=args.temperature,
        base_model_prompt=initial_gen_prompt
    )

    # now we have the initial responses y_0, we can start the self-refinement process
    # history[i] is the list of rounds for problem i; round 0 also carries the
    # question and the extracted gold label.
    history = [
        [{
            'question': row['question'],
            'response': row['response'],
            'label': extract_math_answer(row['label'])
        }]
        for row in initial_records
    ]
    # Without records there is nothing to refine and the accuracy ratios
    # below would divide by zero.
    if not history:
        raise RuntimeError("Initial inference produced no records; nothing to refine or score.")

    for round_idx in range(args.n_rounds):
        print(f"Round {round_idx + 1}")
        # we first generate the feedback
        feedbacks = generate_feedback(model, history, args.max_new_tokens, args.temperature, feedback_prompt)
        for entry, feedback in zip(history, feedbacks):
            entry[-1]['feedback'] = feedback
        # then a refined response, appended as a new round (without feedback yet)
        refine_response = generate_refine_response(model, history, args.max_new_tokens, args.temperature, refine_prompt)
        for entry, response in zip(history, refine_response):
            entry.append({
                'response': response,
            })

    # save the history
    with open(f'{args.output_dir}/history.json', 'w', encoding='utf-8') as f:
        json.dump(history, f)

    # analyze the history
    correct = 0
    oracle_correct = 0
    correct_in_first_round = 0
    for entry in history:
        label = entry[0]['label']
        # Oracle: any round's answer matches the label.
        if any(extract_math_answer(turn['response']) == label for turn in entry):
            oracle_correct += 1
        if extract_math_answer(entry[0]['response']) == label:
            correct_in_first_round += 1
        # Final answer: the first response whose feedback verifies; the last
        # round is never checked for feedback, and is used as the fallback.
        chosen = entry[-1]
        for turn in entry[:-1]:
            if 'feedback' in turn and verify_from_response(turn['feedback']):
                chosen = turn
                break
        if extract_math_answer(chosen['response']) == label:
            correct += 1

    total = len(history)
    accuracy = correct / total
    oracle_accuracy = oracle_correct / total
    first_round_accuracy = correct_in_first_round / total
    delta = accuracy - first_round_accuracy
    print(f"Accuracy: {accuracy}")
    print(f"Oracle Accuracy: {oracle_accuracy}")
    print(f"Accuracy in first round: {first_round_accuracy}")
    print(f"Delta: {delta}")
    with open(f'{args.output_dir}/accuracy.txt', 'w', encoding='utf-8') as f:
        f.write(f"Final Accuracy: {accuracy}\n")
        f.write(f"Accuracy in first round: {first_round_accuracy}\n")
        f.write(f"Delta: {delta}\n")
        
