import argparse
import gc
import json
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Third-party imports
# from unsloth import FastLanguageModel, is_bfloat16_supported
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from trl import GRPOConfig, GRPOTrainer
import numpy as np
from openai import OpenAI
# Local imports
from src.data_prep import make_feedback_dataset
from src.helper import extract_math_answer, extract_last_question_and_solution, verify_from_response

# Process-wide environment switches, applied when this module is loaded.
# NOTE(review): these are set after torch/transformers were imported above;
# presumably the libraries read them lazily -- confirm they still take effect.
os.environ.update({
    "TORCH_COMPILE_DISABLE": "1",
    "PYTORCH_DISABLE_DYNAMO": "1",
    "WANDB_LOG_MODEL": "false",
})

def parse_args():
    """Parse the command-line options of the GRPO feedback-training script.

    Options are declared in a table and registered in the order listed, so
    ``--help`` shows them exactly in that order.

    Returns:
        argparse.Namespace: one attribute per option, named after the flag
        without its leading dashes.
    """
    option_table = [
        # Parsed but not read elsewhere in the visible part of this script.
        ('--stop', {'type': int, 'default': 9999999}),
        # Model name sent to the completion server for second-round inference;
        # also the fallback for --train_model_name.
        ('--base_model_name', {'type': str, 'default': None}),
        # Model (and tokenizer) loaded for training.
        ('--train_model_name', {'type': str, 'default': None}),
        # Root directory for args.json, checkpoints and the input dataset path.
        ('--output_dir', {'type': str, 'default': 'output'}),
        # Forwarded as num_train_epochs.
        ('--n_epoch', {'type': int, 'default': 5}),
        ('--learning_rate', {'type': float, 'default': 1e-5}),
        # Used for both the per-device train and eval batch size.
        ('--batch_size', {'type': int, 'default': 6}),
        ('--gradient_accumulation_steps', {'type': int, 'default': 4}),
        # Forwarded as GRPOConfig(num_generations=...).
        ('--num_generations', {'type': int, 'default': 8}),
        # Forwarded as GRPOConfig(temperature=...).
        ('--temperature', {'type': float, 'default': 1.0}),
        # Forwarded as max_completion_length.
        ('--max_new_tokens', {'type': int, 'default': 2000}),
        # Forwarded as max_prompt_length.
        ('--max_seq_length', {'type': int, 'default': 2048}),
        # save_steps is computed as 1 / this value.
        ('--how_many_checkpoints', {'type': int, 'default': 10}),
        # Space-separated reward names, resolved by get_reward().
        ('--reward_to_use', {'type': str, 'default': 'siv'}),
        # Forwarded as GRPOConfig(beta=...).
        ('--beta', {'type': float, 'default': 0.0}),
        # NOTE: the default is the literal string 'None', not the None object.
        ('--target_model_name', {'type': str, 'default': 'None'}),
        # Directory that must contain feedback.md and refine.md.
        ('--prompt_dir', {'type': str, 'default': 'prompts/feedback'}),
        # Enables the debug prints inside the reward functions.
        ('--debug', {'action': 'store_true', 'default': False}),
        # Passed to trainer.train(resume_from_checkpoint=...).
        ('--checkpoint', {'type': str, 'default': None}),
        # Port of the local completion server used by get_from_server().
        ('--port', {'type': int, 'default': 8070}),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()

def get_from_server(model_name, prompts, max_new_tokens, temperature):
    """Request completions for ``prompts`` from the local OpenAI-compatible server.

    The server address is ``http://localhost:<args.port>/v1``; ``args`` is the
    module-level namespace created in the ``__main__`` section, so this
    function only works when the file is run as a script.

    Args:
        model_name: model identifier forwarded to the completions endpoint.
        prompts: prompt (or list of prompts) forwarded as ``prompt``.
        max_new_tokens: forwarded as ``max_tokens``.
        temperature: sampling temperature forwarded unchanged.

    Returns:
        list[str]: the whitespace-stripped text of every returned choice, in
        the order the server returned them (presumably aligned with
        ``prompts`` -- confirm against the server).
    """
    # Generation is cut at any of these section headings.
    stop_markers = ["Problem", "Solution", "Analysis", "Feedback Received", "Original Problem", "Revised Solution", "Instruction"]
    endpoint = f"http://localhost:{int(args.port)}/v1"
    # vLLM doesn't need a real key, but the client requires one to be set.
    client = OpenAI(api_key="EMPTY", base_url=endpoint)
    response = client.completions.create(
        model=model_name,
        prompt=prompts,
        max_tokens=max_new_tokens,
        temperature=temperature,
        stop=stop_markers,
    )
    texts = []
    for choice in response.choices:
        texts.append(choice.text.strip())
    return texts

def vanilla_second_round_correctness_reward_func(prompts, completions, ground_truth, **kwargs):
    """Reward feedback by whether the revised (second) attempt is correct.

    For every (prompt, feedback) pair a refine prompt is built from the
    question and initial solution embedded in the prompt plus the generated
    feedback. The base model answers it through the completion server, and the
    reward is 1.0 when the extracted answer equals the ground-truth answer,
    else 0.0.

    Args:
        prompts: feedback prompts; each embeds the question and initial solution.
        completions: generated feedback texts, aligned with ``prompts``.
        ground_truth: reference solutions, aligned with ``prompts``.
        **kwargs: ignored; accepted so extra keyword arguments do not break the call.

    Returns:
        np.ndarray: one float reward (1.0 or 0.0) per completion.
    """
    batched_prompts = []
    for prompt, completion in zip(prompts, completions):
        # Parse once per prompt: index 0 is the question, index 1 the initial solution.
        parsed = extract_last_question_and_solution(prompt)
        batched_prompts.append(refine_prompt.format(
            question=parsed[0],
            initial_response=parsed[1],
            feedback=completion,
        ))

    if not batched_prompts:
        # Nothing to score; avoid sending an empty request to the server.
        return np.array([], dtype=float)

    # Send the refine prompts (not the raw feedback prompts) and decode greedily,
    # matching the second-round call in siv_second_round_correctness_reward_func.
    batch_responses = get_from_server(args.base_model_name, batched_prompts, max_new_tokens=512, temperature=0)

    return np.array([
        1.0 if extract_math_answer(response) == extract_math_answer(gt) else 0.0
        for response, gt in zip(batch_responses, ground_truth)
    ])

def siv_second_round_correctness_reward_func(prompts, completions, ground_truth, **kwargs):
    """Reward feedback that flags a wrong first attempt and leads to a correct revision.

    The reward is 1.0 only when all three hold: the initial attempt is wrong,
    the feedback says it is wrong, and the second attempt generated with that
    feedback matches the ground truth. Every other case scores 0.0, so the
    completion server is queried only for "wrong attempt + feedback says
    wrong" items.

    Args:
        prompts: feedback prompts; each embeds the question and initial solution.
        completions: generated feedback texts, aligned with ``prompts``.
        ground_truth: reference solutions, aligned with ``prompts``.
        **kwargs: ignored; accepted so extra keyword arguments do not break the call.

    Returns:
        np.ndarray: one float reward (1.0 or 0.0) per completion.

    Raises:
        ValueError: if the server returns a different number of responses than
            the number of refine prompts sent.
    """
    # Parse each prompt once: index 0 is the question, index 1 the initial solution.
    parsed_prompts = [extract_last_question_and_solution(prompt) for prompt in prompts]
    initial_answers = [extract_math_answer(parsed[1]) for parsed in parsed_prompts]
    gt_answers = [extract_math_answer(gt) for gt in ground_truth]

    rewards = []
    batched_prompts = []
    indices_for_generation = []

    for i, (parsed, completion, initial_ans, gt_ans) in enumerate(
            zip(parsed_prompts, completions, initial_answers, gt_answers)):
        initial_correct = (initial_ans == gt_ans)
        feedback_says_correct = verify_from_response(completion)

        # 0.0 placeholder; overwritten below for items that need a second round.
        rewards.append(0.0)
        if initial_correct or feedback_says_correct:
            # Either nothing needed fixing, or the feedback misjudged the attempt.
            continue

        # Initial attempt wrong and feedback agrees: needs second-round inference.
        indices_for_generation.append(i)
        batched_prompts.append(refine_prompt.format(
            question=parsed[0],
            initial_response=parsed[1],
            feedback=completion,
        ))

    # Run second round inference if needed
    if batched_prompts:
        if args.debug:
            print(f"DEBUG: Sending {len(batched_prompts)} prompts for inference")
            print(f"DEBUG: indices_for_generation = {indices_for_generation}")
            print("prompts used for refine (only for initial wrong + feedback wrong):================================================")
            print(batched_prompts[0])
            print('-'*100)
        batch_responses = get_from_server(args.base_model_name, batched_prompts, max_new_tokens=512, temperature=0)
        if len(batch_responses) != len(batched_prompts):
            # A mismatch would silently pair responses with the wrong items.
            raise ValueError(
                f"Expected {len(batched_prompts)} responses from the server, "
                f"got {len(batch_responses)}"
            )
        if args.debug:
            print(f"DEBUG: Received {len(batch_responses)} responses")
            print("++++++++++++ second round attempt +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
            print(batch_responses[0])
        # Write each second-round result back to its original position.
        for gen_idx, response in zip(indices_for_generation, batch_responses):
            rewards[gen_idx] = 1.0 if extract_math_answer(response) == gt_answers[gen_idx] else 0.0
    return np.array(rewards, dtype=float)

def verifier_accuracy_reward_func(prompts, completions, ground_truth, **kwargs):
    """Reward feedback whose verdict matches the real correctness of the first attempt.

    Scores 1 when the initial answer equals the ground truth and
    ``verify_from_response`` accepts the feedback, or when the initial answer
    differs and ``verify_from_response`` rejects it; 0 otherwise.

    Args:
        prompts: feedback prompts; each embeds the question and initial solution.
        completions: generated feedback texts, aligned with ``prompts``.
        ground_truth: reference solutions, aligned with ``prompts``.
        **kwargs: ignored.

    Returns:
        np.ndarray: one integer score (1 or 0) per completion.
    """
    initial_responses = []
    for prompt in prompts:
        initial_responses.append(extract_last_question_and_solution(prompt)[1])
    initial_answers = list(map(extract_math_answer, initial_responses))
    gt_answers = list(map(extract_math_answer, ground_truth))

    if args.debug:
        # Show the first item of each stage between two banner lines.
        banner = '<>'*100
        print(banner)
        for heading, values in (
            ("initial responses:================================================", initial_responses),
            ("initial answers:================================================", initial_answers),
            ("gt answers:================================================", gt_answers),
            ("completions:================================================", completions),
        ):
            print(heading)
            print(values[0])
        print(banner)

    scores = []
    for initial_answer, gt_answer, feedback in zip(initial_answers, gt_answers, completions):
        # The verifier is consulted exactly once per item.
        if initial_answer == gt_answer:
            verdict_matches = bool(verify_from_response(feedback))
        else:
            verdict_matches = not verify_from_response(feedback)
        scores.append(1 if verdict_matches else 0)

    return np.array(scores)

def strict_verifier_accuracy_reward_func(prompts, completions, ground_truth, **kwargs):
    """Reward feedback that states exactly one verdict phrase, and the right one.

    The feedback is matched case-insensitively against two phrases:
    "there is no error" and "there is at least one error". Scores 1 when the
    initial answer is right and only the "no error" phrase appears, or when
    the initial answer is wrong and only the "at least one error" phrase
    appears. Feedback with neither phrase, both phrases, or the wrong phrase
    scores 0.

    Args:
        prompts: feedback prompts; each embeds the question and initial solution.
        completions: generated feedback texts, aligned with ``prompts``.
        ground_truth: reference solutions, aligned with ``prompts``.
        **kwargs: ignored.

    Returns:
        np.ndarray: one integer score (1 or 0) per completion.
    """
    initial_responses = []
    for prompt in prompts:
        initial_responses.append(extract_last_question_and_solution(prompt)[1])
    initial_answers = list(map(extract_math_answer, initial_responses))
    gt_answers = list(map(extract_math_answer, ground_truth))

    if args.debug:
        # Show the first item of each stage between two banner lines.
        banner = '<>'*100
        print(banner)
        for heading, values in (
            ("initial responses:================================================", initial_responses),
            ("initial answers:================================================", initial_answers),
            ("gt answers:================================================", gt_answers),
            ("completions:================================================", completions),
        ):
            print(heading)
            print(values[0])
        print(banner)

    scores = []
    for initial_answer, gt_answer, feedback in zip(initial_answers, gt_answers, completions):
        lowered = feedback.lower()
        claims_no_error = "there is no error" in lowered
        claims_error = "there is at least one error" in lowered

        if claims_no_error == claims_error:
            # Neither phrase or both phrases: no usable conclusion.
            score = 0
        elif initial_answer == gt_answer:
            score = 1 if claims_no_error else 0
        else:
            score = 1 if claims_error else 0
        scores.append(score)

    return np.array(scores)

def false_positive_verifier_reward_func(prompts, completions, ground_truth, **kwargs):
    """Strict verifier reward with a penalty for flagging a correct attempt.

    The feedback is matched case-insensitively against "there is no error"
    and "there is at least one error"; a verdict counts only when exactly one
    of the two phrases appears.

    Scores:
        1   verdict matches reality (right attempt + "no error", or wrong
            attempt + "at least one error").
        -1  the attempt is right but the feedback claims an error.
        0   everything else: no or contradictory verdict, or a wrong attempt
            declared error-free.

    Args:
        prompts: feedback prompts; each embeds the question and initial solution.
        completions: generated feedback texts, aligned with ``prompts``.
        ground_truth: reference solutions, aligned with ``prompts``.
        **kwargs: ignored.

    Returns:
        np.ndarray: one integer score (1, 0 or -1) per completion.
    """
    initial_responses = []
    for prompt in prompts:
        initial_responses.append(extract_last_question_and_solution(prompt)[1])
    initial_answers = list(map(extract_math_answer, initial_responses))
    gt_answers = list(map(extract_math_answer, ground_truth))

    if args.debug:
        # Show the first item of each stage between two banner lines.
        banner = '<>'*100
        print(banner)
        for heading, values in (
            ("initial responses:================================================", initial_responses),
            ("initial answers:================================================", initial_answers),
            ("gt answers:================================================", gt_answers),
            ("completions:================================================", completions),
        ):
            print(heading)
            print(values[0])
        print(banner)

    scores = []
    for initial_answer, gt_answer, feedback in zip(initial_answers, gt_answers, completions):
        lowered = feedback.lower()
        claims_no_error = "there is no error" in lowered
        claims_error = "there is at least one error" in lowered

        if claims_no_error == claims_error:
            # Neither phrase or both phrases: no usable conclusion.
            score = 0
        elif initial_answer == gt_answer:
            # Right attempt: reward "no error", penalise a claimed error.
            score = 1 if claims_no_error else -1
        else:
            # Wrong attempt: reward a claimed error; declaring it clean scores 0.
            score = 1 if claims_error else 0
        scores.append(score)

    return np.array(scores)

def feedback_format_reward_func(completions, **kwargs):
    """Reward feedback that contains an explicit verdict phrase.

    A completion scores 1.0 when, compared case-insensitively, it contains
    "there is no error", "there is an error" or "there is at least one error";
    otherwise it scores 0.0.

    The previous check, ``"There is no error" or "..." in feedback.lower()``,
    was always truthy because a non-empty string literal is truthy, so every
    completion scored 1.0.

    NOTE(review): "there is at least one error" is accepted in addition to the
    original two phrases because the strict verifier rewards in this file look
    for it -- confirm against the wording requested in the feedback prompt.

    Args:
        completions: generated feedback texts.
        **kwargs: ignored.

    Returns:
        np.ndarray: one float score (1.0 or 0.0) per completion.
    """
    # All phrases are lower-case because they are matched against lowered text.
    verdict_phrases = (
        "there is no error",
        "there is an error",
        "there is at least one error",
    )

    def is_correct_format(feedback):
        lowered = feedback.lower()
        return any(phrase in lowered for phrase in verdict_phrases)

    return np.array([1.0 if is_correct_format(completion) else 0.0 for completion in completions])

def get_reward(reward_names, weight=None):
    """Resolve reward names to the reward functions registered in ``reward_dict``.

    ``reward_dict`` is a module-level mapping created in the ``__main__``
    section, so this function only works when the file is run as a script.

    Args:
        reward_names: either a whitespace-separated string of names or an
            iterable of names.
        weight: accepted but not used.

    Returns:
        list: the reward callables, in the order the names were given.

    Raises:
        ValueError: at the first name that is not a key of ``reward_dict``.
    """
    def lookup(name):
        if name in reward_dict:
            return reward_dict[name]
        raise ValueError(f"Unknown reward function: {name}. "
                           f"Available: {list(reward_dict.keys())}")

    # A plain string is treated as a space-separated list of names.
    names = reward_names.split() if isinstance(reward_names, str) else reward_names
    return [lookup(name) for name in names]

# Script entry point: parse options, load the model, prompts and dataset,
# configure GRPO and train. The names `args`, `reward_dict` and `refine_prompt`
# assigned here become module globals that the functions above read directly.
if __name__ == "__main__":
    args = parse_args()
    # Maps the names accepted by --reward_to_use to reward functions (read by get_reward).
    # NOTE: feedback_format_reward_func is defined above but not registered here.
    reward_dict = {
        'vanilla': vanilla_second_round_correctness_reward_func,
        'siv': siv_second_round_correctness_reward_func,
        'verifier': verifier_accuracy_reward_func,
        'verifier_strict': strict_verifier_accuracy_reward_func,
        'verifier_penalty': false_positive_verifier_reward_func,
    }
    # Train the base model itself when no separate training model is given.
    if not args.train_model_name:
        args.train_model_name = args.base_model_name
    os.makedirs(args.output_dir, exist_ok=True)
    # Written after the fallback above, so args.json records the resolved train_model_name.
    with open(f'{args.output_dir}/args.json', 'w') as f:
        json.dump(vars(args), f)
    set_seed(42)
    # Resolve --reward_to_use (space-separated names) into a list of reward callables.
    reward_func = get_reward(args.reward_to_use)
    print("debug reward")
    print(reward_func)
    
    os.environ["WANDB_PROJECT"] = "Feedback"  # W&B project name
    # Load the model to be trained (model_f) and its tokenizer in bfloat16.
    model_f = AutoModelForCausalLM.from_pretrained(args.train_model_name, torch_dtype=torch.bfloat16)
    tokenizer_f = AutoTokenizer.from_pretrained(args.train_model_name)
    # feedback.md is the template handed to make_feedback_dataset.
    with open(f'{args.prompt_dir}/feedback.md', 'r') as f:
        feedback_prompt = f.read()
    # refine.md is the template the second-round reward functions fill via
    # refine_prompt.format(question=..., initial_response=..., feedback=...).
    with open(f'{args.prompt_dir}/refine.md', 'r') as f:
        refine_prompt = f.read()
    # Same tokenizer object under a second name; only used to build the dataset.
    tokenizer_0 = tokenizer_f
    # Input path is <output_dir>/star_0/initial_run -- presumably written by an
    # earlier pipeline stage; verify it exists before launching.
    feedback_hf_dataset = make_feedback_dataset(f"{args.output_dir}/star_0/initial_run", feedback_prompt, tokenizer_0)
    # Print the first five prompts as a sanity check.
    print("feedback_hf_dataset:================================================")
    for d in feedback_hf_dataset['prompt'][:5]:
        print(d)
        print('-'*100)
    # CPU count from SLURM (1 when unset); only printed, not used below.
    num_cpu = int(os.environ.get('SLURM_CPUS_PER_TASK')) if os.environ.get('SLURM_CPUS_PER_TASK') else 1
    print(f'num_cpu: {num_cpu}')
    grpo_config = GRPOConfig(
        # peft_config=lora_config,
        # eval_strategy='steps',
        # eval_steps=0.1,
        # dataloader_num_workers=1,
        # dataloader_pin_memory=True,
        use_vllm=True,
        vllm_mode="colocate",
        vllm_gpu_memory_utilization=0.3,  # Max safe allocation
        # Shorter stop list than the one used by get_from_server.
        generation_kwargs={"stop": ["Problem", "Solution", "Analysis", "Instruction"]},
        optim="paged_adamw_8bit",
        bf16=True,
        fp16=False,
        num_generations=args.num_generations,
        output_dir=f"{args.output_dir}/model_i_checkpoints",
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        do_eval=False,
        logging_steps=5,
        logging_strategy='steps',
        seed=42,
        num_train_epochs=args.n_epoch,
        logging_first_step=True,
        lr_scheduler_type='constant_with_warmup',
        report_to='wandb',
        learning_rate=args.learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.05,
        #save_strategy='epoch',
        # NOTE(review): a float below 1 is presumably treated by the trainer as a
        # fraction of total training steps -- confirm. --how_many_checkpoints 0
        # raises ZeroDivisionError here, and 1 yields 1.0.
        save_steps=float(1/args.how_many_checkpoints),
        max_prompt_length=args.max_seq_length,
        max_completion_length=args.max_new_tokens,
        temperature=args.temperature,
        # Presumably keeps extra dataset columns (e.g. ground_truth) available to
        # the reward functions -- confirm against the trainer.
        remove_unused_columns=False,
        beta=args.beta
    )

    trainer = GRPOTrainer(
        model=model_f,
        train_dataset=feedback_hf_dataset,
        args=grpo_config,
        reward_funcs=reward_func
    )

    # args.checkpoint is None unless --checkpoint was given.
    trainer.train(resume_from_checkpoint=args.checkpoint)
    # Drop the trainer and release cached GPU memory before the process exits.
    del trainer
    gc.collect()
    torch.cuda.empty_cache()

