import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.model_loader import base_model_inference_aio
from src.helper import extract_math_answer
from src.data_prep import process_aime, process_gsm, process_mmlu, process_sudoku
from src.pipelines import run_math_inference
from transformers import AutoTokenizer
import torch
from vllm import LLM
from datasets import DatasetDict
import argparse
import json

def parse_args():
    """Parse the command-line options for a self-refine evaluation run.

    Options:
        --model_name: model id/path handed to vLLM and the tokenizer.
        --dataset_name: GSM, AIME, MMLU, or SUDOKU_<num_prefilled>.
        --output_dir: where history.json and accuracy.txt are written.
        --max_new_tokens: generation budget for each model call.
        --n_rounds: number of judge -> reflect -> refine rounds.
        --stop: cap on examples taken from each dataset split.
        --temperature: sampling temperature for every generation.
        --prompt_dir: root holding <dataset>/{reward,reflection,refine}.md.

    Returns:
        argparse.Namespace with one attribute per option.
    """
    option_specs = (
        ("--model_name", str, "Qwen/Qwen2.5-1.5B-Instruct"),
        ("--dataset_name", str, "AIME"),
        ("--output_dir", str, "outputs"),
        ("--max_new_tokens", int, 1024),
        ("--n_rounds", int, 3),
        ("--stop", int, 1000),
        ("--temperature", float, 0.7),
        ("--prompt_dir", str, "prompts/self_refine"),
    )
    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli.parse_args()

def create_reward_prompt(history, prompt):
    """Build the judgement prompt for one trajectory.

    Args:
        history: Round dicts for one problem; ``history[0]`` must hold
            ``'question'`` and ``history[-1]`` must hold ``'response'``.
        prompt: Prefix text (contents of ``reward.md``), used verbatim.

    Returns:
        ``prompt`` followed by the stripped question, the stripped latest
        response and an open ``Judgement:`` slot.
    """
    first_round, latest_round = history[0], history[-1]
    question = first_round['question']
    answer = latest_round['response']
    return prompt + f"Problem: {question.strip()}\n\nSolution: {answer.strip()}\n\nJudgement:"

def generate_reward(model, history, max_new_tokens, temperature, prompt):
    """Ask the model to judge the latest response of every trajectory.

    Args:
        model: Inference engine forwarded to ``base_model_inference_aio``.
        history: One list of round dicts per problem (see ``create_reward_prompt``).
        max_new_tokens: Generation budget per judgement.
        temperature: Sampling temperature.
        prompt: Prefix text (contents of ``reward.md``).

    Returns:
        One value per trajectory, in ``history`` order, obtained by running
        ``extract_math_answer`` on each stripped generation; ``[]`` when
        ``history`` is empty.
        NOTE(review): callers compare these values with the int ``1``; confirm
        that ``extract_math_answer`` does not return strings.
    """
    if not history:
        # Nothing to judge; also avoids an IndexError when echoing prompts[0].
        return []
    prompts = [create_reward_prompt(each, prompt) for each in history]
    print("========== for reward ==========")
    print(prompts[0])  # echo one sample prompt for debugging
    # The prompt section headers are passed as stop tokens.
    results = base_model_inference_aio(model, prompts, temperature=temperature, max_new_tokens=max_new_tokens, stop_tokens=["Problem:", "Solution:", "Judgement:", "Reflection:"])
    return [extract_math_answer(result.strip()) for result in results]

def create_reflection_prompt(history, prompt):
    """Build the reflection prompt for one trajectory.

    Appends the first round's question, the latest response and a verbal
    judgement derived from the latest ``reward`` to ``prompt``, ending with an
    open ``Reflection:`` slot.

    Args:
        history: Round dicts for one problem; ``history[0]`` must hold
            ``'question'`` and ``history[-1]`` must hold ``'response'`` and
            ``'reward'``.
        prompt: Prefix text (contents of ``reflection.md``).

    Returns:
        The completed prompt string.
    """
    question = history[0]['question']
    latest = history[-1]
    response = latest['response']
    # Only a reward comparing equal to 1 is verbalised as "correct".
    # NOTE(review): the reward comes from extract_math_answer; if that returns a
    # string such as "1", this comparison is never true. Confirm its type.
    if latest['reward'] == 1:
        judgement = "The above answer is correct."
    else:
        judgement = "The above answer is incorrect."
    parts = [
        prompt,
        "\n\n",
        f"Problem: {question.strip()}\n\n",
        f"Solution: {response.strip()}\n\n",
        f"Judgement: {judgement}\n\n",
        "Reflection:",
    ]
    return "".join(parts)

def generate_reflection(model, history, max_new_tokens, temperature, prompt):
    """Generate a reflection on the latest, already-judged response of each trajectory.

    Args:
        model: Inference engine forwarded to ``base_model_inference_aio``.
        history: One list of round dicts per problem; the last dict of each
            must already carry ``'reward'`` (see ``create_reflection_prompt``).
        max_new_tokens: Generation budget per reflection.
        temperature: Sampling temperature.
        prompt: Prefix text (contents of ``reflection.md``).

    Returns:
        Whatever ``base_model_inference_aio`` returns for the prompts, in
        ``history`` order (presumably one string per trajectory; confirm), or
        ``[]`` when ``history`` is empty.
    """
    if not history:
        # Nothing to reflect on; also avoids an IndexError when echoing prompts[0].
        return []
    prompts = [create_reflection_prompt(each, prompt) for each in history]
    print("========== for reflection ==========")
    print(prompts[0])  # echo one sample prompt for debugging
    # The prompt section headers are passed as stop tokens.
    results = base_model_inference_aio(model, prompts, temperature=temperature, max_new_tokens=max_new_tokens, stop_tokens=["Problem:", "Solution:", "Judgement:", "Reflection:"])
    return results

def create_refine_prompt(history, prompt):
    """Build the refinement prompt for one trajectory.

    Every round's reflection is joined with newlines into a single memory
    section, so the refined solution is conditioned on all reflections so far.

    Args:
        history: Round dicts for one problem; ``history[0]`` must hold
            ``'question'`` and every dict must hold ``'reflection'``.
        prompt: Prefix text (contents of ``refine.md``).

    Returns:
        The completed prompt string, ending with an open ``Solution:`` slot.
    """
    question = history[0]['question']
    reflection_memory = "\n".join(f"{step['reflection']}" for step in history)
    body = (
        f"Problem: {question.strip()}\n\n"
        f"Reflection: {reflection_memory.strip()}\n\n"
        "Solution:"
    )
    return prompt + "\n" + body

def generate_refine_response(model, history, max_new_tokens, temperature, prompt):
    """Generate a refined solution for each trajectory from its accumulated reflections.

    Args:
        model: Inference engine forwarded to ``base_model_inference_aio``.
        history: One list of round dicts per problem; every dict must carry
            ``'reflection'`` (see ``create_refine_prompt``).
        max_new_tokens: Generation budget per refined solution.
        temperature: Sampling temperature.
        prompt: Prefix text (contents of ``refine.md``).

    Returns:
        Whatever ``base_model_inference_aio`` returns for the prompts, in
        ``history`` order (presumably one string per trajectory; confirm), or
        ``[]`` when ``history`` is empty.
    """
    if not history:
        # Nothing to refine; also avoids an IndexError when echoing prompts[0].
        return []
    prompts = [create_refine_prompt(each, prompt) for each in history]
    print("========== for refine ==========")
    print(prompts[0])  # echo one sample prompt for debugging
    # The prompt section headers are passed as stop tokens.
    results = base_model_inference_aio(model, prompts, temperature=temperature, max_new_tokens=max_new_tokens, stop_tokens=["Problem:", "Solution:", "Judgement:", "Reflection:"])
    return results
def _load_dataset(dataset_name):
    """Load the dataset named on the command line.

    ``SUDOKU_<k>`` loads a Sudoku set with ``k`` prefilled cells and maps to the
    prompt-directory key ``"SUDOKU"``. Every other name must be a registered
    processor and is returned unchanged as the key.

    Returns:
        ``(dataset, prompt_key)``, where ``prompt_key`` selects the prompt files.

    Raises:
        ValueError: if ``dataset_name`` is not a supported dataset.
    """
    processors = {
        "GSM": process_gsm,
        "AIME": process_aime,
        "MMLU": process_mmlu,
    }
    if "SUDOKU" in dataset_name:
        num_prefilled = int(dataset_name.split("_")[-1])
        return process_sudoku(num_prefilled=num_prefilled), "SUDOKU"
    if dataset_name not in processors:
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}; expected one of "
            f"{sorted(processors)} or SUDOKU_<num_prefilled>."
        )
    return processors[dataset_name](), dataset_name


def _read_text(path):
    """Return the full contents of the UTF-8 text file at ``path``."""
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


def _select_final_response(entry):
    """Return the response that is scored as a trajectory's final answer.

    The first non-final round whose reward equals 1 wins, i.e. the model judged
    it correct itself. Otherwise the last (most recently refined) response is
    used. The last round is never checked for a reward: it is appended after
    that round's judgement and so has none.
    """
    for round_dict in entry[:-1]:
        if round_dict.get('reward') == 1:
            return round_dict['response']
    return entry[-1]['response']


def _compute_accuracies(history):
    """Return ``(final_accuracy, first_round_accuracy)`` over all trajectories.

    Returns ``(0.0, 0.0)`` for an empty history instead of dividing by zero.
    """
    if not history:
        return 0.0, 0.0
    correct = 0
    correct_in_first_round = 0
    for entry in history:
        label = entry[0]['label']
        if extract_math_answer(entry[0]['response']) == label:
            correct_in_first_round += 1
        if extract_math_answer(_select_final_response(entry)) == label:
            correct += 1
    total = len(history)
    return correct / total, correct_in_first_round / total


def main():
    """Run the initial generation, then n_rounds of judge -> reflect -> refine, and report accuracy."""
    # Pins the run to GPU 0, overriding the caller's environment.
    # NOTE(review): torch/vllm are imported at module top before this line; this
    # relies on CUDA not having been initialised yet.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    args = parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Resolve the dataset and prompt files before the slow model load, so a bad
    # --dataset_name or a missing prompt file fails immediately.
    dataset, args.dataset_name = _load_dataset(args.dataset_name)
    initial_gen_prompt = _read_text(f'prompts/base_model/inference_{args.dataset_name}.md')
    prompt_root = f'{args.prompt_dir}/{args.dataset_name}'
    reward_prompt = _read_text(f'{prompt_root}/reward.md')
    reflection_prompt = _read_text(f'{prompt_root}/reflection.md')
    refine_prompt = _read_text(f'{prompt_root}/refine.md')

    num_gpus = torch.cuda.device_count()
    model = LLM(
        model=args.model_name,
        dtype='half',
        tensor_parallel_size=num_gpus,
        # Scales with the total generation budget across rounds, floored at 20000.
        max_model_len=max(args.n_rounds * args.max_new_tokens, 20000),
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # NOTE(review): only the 'test' split is evaluated below; the train/val
    # splits are built but not used in this script.
    dataset_train_valid = dataset['train'].train_test_split(test_size=0.1, seed=42)
    full_dataset = DatasetDict({
        'train': dataset_train_valid['train'][:args.stop],
        'val': dataset_train_valid['test'][:args.stop],
        'test': dataset['test'][:args.stop],
    })
    _, initial_records = run_math_inference(
        model,
        tokenizer,
        data=(full_dataset['test']['question'], full_dataset['test']['answer']),
        save_path=f'{args.output_dir}/temp',
        batch_size=500,
        max_new_tokens=args.max_new_tokens,
        COT=True,
        temperature=args.temperature,
        base_model_prompt=initial_gen_prompt,
    )

    # One trajectory per problem: a list of round dicts starting with the initial answer.
    history = [
        [{
            'question': row['question'],
            'response': row['response'],
            'label': extract_math_answer(row['label']),
        }]
        for row in initial_records
    ]

    for round_idx in range(args.n_rounds):
        print(f"Round {round_idx + 1}")
        # 1) Judge the latest response of every trajectory.
        rewards = generate_reward(model, history, args.max_new_tokens, args.temperature, reward_prompt)
        for entry, reward in zip(history, rewards):
            entry[-1]['reward'] = reward
        # 2) Reflect on that response given its judgement.
        reflections = generate_reflection(model, history, args.max_new_tokens, args.temperature, reflection_prompt)
        for entry, reflection in zip(history, reflections):
            entry[-1]['reflection'] = reflection
        # 3) Produce a refined answer; it becomes the next round's "latest" response.
        refined = generate_refine_response(model, history, args.max_new_tokens, args.temperature, refine_prompt)
        for entry, response in zip(history, refined):
            entry.append({'response': response})

    with open(f'{args.output_dir}/history.json', 'w', encoding='utf-8') as f:
        json.dump(history, f)

    final_acc, first_acc = _compute_accuracies(history)
    summary = [
        f"Final Accuracy: {final_acc}",
        f"Accuracy in first round: {first_acc}",
        f"Delta: {final_acc - first_acc}",
    ]
    for line in summary:
        print(line)
    with open(f'{args.output_dir}/accuracy.txt', 'w', encoding='utf-8') as f:
        f.writelines(line + "\n" for line in summary)


if __name__ == "__main__":
    main()
    



