import argparse
import gc
import json
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Third-party imports
# from unsloth import FastLanguageModel, is_bfloat16_supported
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from trl import GRPOConfig, GRPOTrainer
from vllm import LLM, SamplingParams
import numpy as np
import time
import threading
# Local imports
from src.data_prep import make_feedback_dataset
from src.helper import extract_math_answer, extract_last_question_and_solution, verify_from_response
from src.model_loader import model_inference_batch_vllm, model_inference_batch_vllm_aio

# Environment settings
# These are assigned when the module is loaded, before any model or trainer is built below.
# NOTE(review): torch, transformers, trl and vllm have already been imported by this point;
# confirm each library reads its variable lazily rather than once at import time.
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
# presumably stops Weights & Biases from uploading model checkpoints — verify against the wandb docs
os.environ["WANDB_LOG_MODEL"] = "false"

def parse_args():
    """Parse the command-line options of the GRPO feedback-training script.

    Returns:
        argparse.Namespace: one attribute per flag declared below.
    """
    parser = argparse.ArgumentParser()
    # Not read anywhere in the visible script; kept so existing command lines keep working.
    parser.add_argument('--stop', type=int, default=9999999)
    parser.add_argument('--base_model_name', type=str, default='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B',
                        help='model name for training (also loaded as the reference model)')
    parser.add_argument('--output_dir', type=str, default='output',
                        help='output directory')
    parser.add_argument('--n_epoch', type=int, default=5,
                        help='number of epochs')
    parser.add_argument('--learning_rate', type=float, default=1e-5,
                        help='learning rate')
    parser.add_argument('--batch_size', type=int, default=6,
                        help='per-device train/eval batch size')
    parser.add_argument('--num_generations', type=int, default=8,
                        help='number of generations')
    parser.add_argument('--temperature', type=float, default=1.0,
                        help='sampling temperature')
    parser.add_argument('--max_new_tokens', type=int, default=2000,
                        help='max new tokens (max completion length)')
    parser.add_argument('--max_seq_length', type=int, default=2048,
                        help='max sequence length (max prompt length)')
    # Only referenced by the commented-out LoRA configuration in the main block.
    parser.add_argument('--lora_rank', type=int, default=32,
                        help='lora rank')
    parser.add_argument('--how_many_checkpoints', type=int, default=10,
                        help='how many checkpoints to save (save_steps is set to 1/this value)')
    # The flag is a negative switch: passing it skips loading the reference model.
    parser.add_argument('--no_reference_model', action='store_true', default=False,
                        help='do not load the vLLM reference model')
    # Must be one of the names the __main__ dispatch recognises; the default is the
    # plain second-round-correctness reward so that running without the flag works.
    parser.add_argument('--reward_to_use', type=str, default='vanilla',
                        choices=['vanilla', 'siv', 'verifier_accuracy', 'both'],
                        help='reward function to use')
    parser.add_argument('--beta', type=float, default=0.0,
                        help='beta value forwarded to GRPOConfig')
    # Not read anywhere in the visible script; note the default is the string 'None'.
    parser.add_argument('--target_model_name', type=str, default='None',
                        help='target model name')
    parser.add_argument('--prompt_dir', type=str, default='prompts/feedback',
                        help='prompt directory (must contain feedback.md and refine.md)')
    parser.add_argument('--debug', action='store_true', default=False,
                        help='whether to debug')
    return parser.parse_args()

class SimpleGPUKeepAlive:
    """Daemon thread that repeatedly runs a throwaway generation on a model.

    The loop can be paused (`stop`), resumed (`resume`) or ended for good
    (`shutdown`). Public attributes `active`, `running` and `thread` are kept
    so existing callers can still inspect them.
    """

    def __init__(self, model, tokenizer, interval=45):
        """
        Simple GPU keep-alive with stop/resume functionality

        Args:
            model: VLLM model instance
            tokenizer: Tokenizer instance (stored, not used by the loop itself)
            interval: Seconds between keep-alive operations
        """
        self.model = model
        self.tokenizer = tokenizer
        self.interval = interval
        self.active = True  # Flag to control if keep-alive should run
        self.running = True  # Flag to control thread lifecycle
        self.thread = None
        # Set by shutdown() so the loop's wait ends at once instead of
        # sleeping out the remainder of `interval` before join() can return.
        self._wake = threading.Event()

    def _keep_alive_loop(self):
        """Background thread that performs the keep-alive generations."""
        while self.running:
            if self.active:
                try:
                    # Dummy workload whose output is discarded
                    dummy_msg = ["Please write a long story about a cat."] * 500
                    print("keep gpu running")
                    _ = self.model.generate(dummy_msg, sampling_params=SamplingParams(temperature=1, max_tokens=1000))
                except Exception as e:
                    # Best effort: a failed dummy generation must not kill the thread.
                    print(f"Keep-alive error (continuing): {e}")

            # Behaves like a sleep of `interval` seconds unless shutdown() sets the event.
            self._wake.wait(self.interval)

    def start(self):
        """Start the keep-alive thread"""
        print("starting gpu keeper alive")
        self.thread = threading.Thread(target=self._keep_alive_loop, daemon=True)
        self.thread.start()

    def stop(self):
        """Temporarily stop keep-alive operations"""
        print("stopping gpu keeper alive")
        self.active = False

    def resume(self):
        """Resume keep-alive operations"""
        print("resuming gpu keeper alive")
        self.active = True

    def shutdown(self):
        """Permanently stop the thread and wait for it to finish.

        A generation that is already in flight is not interrupted; join()
        returns once it completes.
        """
        self.running = False
        self._wake.set()
        if self.thread:
            self.thread.join()

# Global keep-alive instance
# Handle to the current SimpleGPUKeepAlive (None until one is created). The
# second-round reward functions shut down whatever is stored here before they run
# inference and then rebind it to a freshly started keeper.
gpu_keeper = None

def vanilla_second_round_correctness_reward_func(prompts, completions, ground_truth, **kwargs):
    """Reward feedback by whether a second attempt conditioned on it is correct.

    For every (prompt, completion) pair a refinement request is built from the
    question, the initial solution found in the prompt and the generated feedback
    (the completion). The reference model `model_0` answers it, and the reward is
    1.0 when the answer extracted from that response equals the answer extracted
    from the ground truth, 0.0 otherwise.

    Relies on module globals set in the __main__ block: `args`, `refine_prompt`,
    `model_0`, `tokenizer_0` and `gpu_keeper`.

    Args:
        prompts: prompt strings given to the feedback model.
        completions: generated feedback, one per prompt.
        ground_truth: reference solutions, one per prompt.
        **kwargs: ignored.

    Returns:
        np.ndarray: one 0.0/1.0 reward per sample.
    """
    global gpu_keeper
    # The keeper shares model_0, so it is stopped before second-round inference.
    if gpu_keeper:
        print("stopping gpu keeper alive")
        gpu_keeper.shutdown()

    if "Instruct" in args.base_model_name:
        # Chat format: one user message holding the filled-in refine prompt.
        # Built only in this branch, because the markers it splits on are only
        # needed (and only guaranteed to matter) for the chat-style prompts.
        batch_messages = [
            [{'role': 'user', 'content': refine_prompt.format(question=prompt.split("**Problem:**")[1].split("**Solution:**")[0].strip(), initial_response=prompt.split("**Solution:**")[1].split("<|im_end|>")[0].strip(), feedback=completion)}]
            for prompt, completion in zip(prompts, completions)
        ]
        base_refine_prompt = None
    else:
        # we are dealing with a base model: pass [question, initial response, feedback]
        # and let the inference helper apply the refine prompt itself
        batch_messages = []
        for prompt, completion in zip(prompts, completions):
            extracted = extract_last_question_and_solution(prompt)
            batch_messages.append([extracted[0], extracted[1], completion])
        base_refine_prompt = refine_prompt

    batch_responses = model_inference_batch_vllm(model_0, tokenizer_0, batch_messages, max_new_tokens=512, retro=False, prompt=base_refine_prompt)

    rewards = np.array([1.0 if extract_math_answer(response) == extract_math_answer(gt) else 0.0 for response, gt in zip(batch_responses, ground_truth)])
    # A new keeper is always started here, even if none was running on entry.
    print("restarting GPU keeper")
    gpu_keeper = SimpleGPUKeepAlive(model_0, tokenizer_0, interval=1)
    gpu_keeper.start()
    return rewards

def siv_second_round_correctness_reward_func(prompts, completions, ground_truth, **kwargs):
    """Reward feedback that flags a wrong first attempt and leads to a correct second one.

    Per sample:
      * initial answer is wrong and the feedback says it is wrong: a second attempt
        is generated with `model_0`; reward is 1.0 if its extracted answer equals
        the ground-truth answer, else 0.0.
      * every other combination (initial answer correct, or feedback judges a wrong
        answer as correct): reward is 0.0 and no inference is run.

    Relies on module globals set in the __main__ block: `args`, `refine_prompt`,
    `model_0`, `tokenizer_0` and `gpu_keeper`.

    Args:
        prompts: prompt strings given to the feedback model.
        completions: generated feedback, one per prompt.
        ground_truth: reference solutions, one per prompt.
        **kwargs: ignored.

    Returns:
        np.ndarray: one 0.0/1.0 reward per sample.
    """
    global gpu_keeper
    # The keeper shares model_0, so it is stopped before second-round inference.
    if gpu_keeper:
        print("stopping gpu keeper alive")
        gpu_keeper.shutdown()

    # Extract initial responses and check correctness
    is_instruct = "Instruct" in args.base_model_name
    if is_instruct:
        parsed_prompts = None
        initial_responses = [prompt.split("**Solution:**")[1].split("<|im_end|>")[0].strip() for prompt in prompts]
        base_refine_prompt = None
    else:
        # Parse each prompt once and reuse the (question, solution) pair below.
        parsed_prompts = [extract_last_question_and_solution(prompt) for prompt in prompts]
        initial_responses = [parsed[1] for parsed in parsed_prompts]
        base_refine_prompt = refine_prompt
    initial_answers = [extract_math_answer(response) for response in initial_responses]
    gt_answers = [extract_math_answer(gt) for gt in ground_truth]

    # Determine which cases need second round inference
    rewards = []
    batch_messages = []
    indices_for_generation = []

    for i, (prompt, completion, initial_ans, gt_ans) in enumerate(zip(prompts, completions, initial_answers, gt_answers)):
        initial_correct = (initial_ans == gt_ans)
        feedback_says_correct = verify_from_response(completion)

        if initial_correct or feedback_says_correct:
            # Either nothing to fix, or the feedback's verdict is wrong.
            rewards.append(0.0)
            continue

        # Initial attempt wrong and feedback agrees: need second round inference.
        rewards.append(None)  # placeholder, filled in after generation
        indices_for_generation.append(i)
        if is_instruct:
            batch_messages.append([{
                'role': 'user',
                'content': refine_prompt.format(
                    question=prompt.split("**Problem:**")[1].split("**Solution:**")[0].strip(),
                    initial_response=initial_responses[i],
                    feedback=completion
                )
            }])
        else:
            batch_messages.append([
                parsed_prompts[i][0],  # question
                initial_responses[i],  # initial response
                completion  # feedback
            ])

    # Run second round inference if needed
    if batch_messages:
        print(f"DEBUG: Sending {len(batch_messages)} messages for inference")
        print(f"DEBUG: indices_for_generation = {indices_for_generation}")
        print("Original prompt:================================================")
        print(prompts[0])
        print('-'*100)
        print("prompts (only for initial wrong + feedback wrong):================================================")
        first_message = batch_messages[0]
        if is_instruct:
            # Chat entries hold a single dict, so there is no [1]/[2] element to print.
            print("++++++++++++ refine prompt +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
            print(first_message[0]['content'])
        else:
            print("++++++++++++ question +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
            print(first_message[0])
            print("++++++++++++ initial response +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
            print(first_message[1])
            print("++++++++++++ feedback +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
            print(first_message[2])
        print('-'*100)
        batch_responses = model_inference_batch_vllm_aio(model_0, tokenizer_0, batch_messages, max_new_tokens=512, retro=False, prompt=base_refine_prompt)
        print(f"DEBUG: Received {len(batch_responses)} responses")

        print("responses:================================================")
        print("++++++++++++ second round attempt +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
        print(batch_responses[0])
        # Fill in rewards
        for idx, gen_idx in enumerate(indices_for_generation):
            response_answer = extract_math_answer(batch_responses[idx])
            rewards[gen_idx] = 1.0 if response_answer == gt_answers[gen_idx] else 0.0
    # A new keeper is always started here, even if none was running on entry.
    print("restarting GPU keeper")
    gpu_keeper = SimpleGPUKeepAlive(model_0, tokenizer_0, interval=1)
    gpu_keeper.start()
    return np.array(rewards)

def verifier_accuracy_reward_func(prompts, completions, ground_truth, **kwargs):
    """Score whether the feedback's verdict matches the initial attempt's correctness.

    A sample scores 1 when the initial answer equals the ground-truth answer and
    the feedback judges it correct, or when the initial answer differs and the
    feedback judges it incorrect. Otherwise it scores 0. The first sample's
    intermediate values are printed for inspection.

    Reads the module global `args` to pick the prompt format.

    Args:
        prompts: prompt strings containing the initial solution.
        completions: generated feedback, one per prompt.
        ground_truth: reference solutions, one per prompt.
        **kwargs: ignored.

    Returns:
        np.ndarray: one 0/1 integer score per sample.
    """
    # Pull the initial solution out of each prompt; the format depends on the model type.
    is_instruct = "Instruct" in args.base_model_name
    initial_responses = []
    for prompt in prompts:
        if is_instruct:
            solution = prompt.split("**Solution:**")[1].split("<|im_end|>")[0].strip()
        else:
            solution = extract_last_question_and_solution(prompt)[1]
        initial_responses.append(solution)

    initial_answers = [extract_math_answer(text) for text in initial_responses]
    gt_answers = [extract_math_answer(reference) for reference in ground_truth]

    # Debug dump of the first sample
    print("initial responses:================================================")
    print(initial_responses[0])
    print("initial answers:================================================")
    print(initial_answers[0])
    print("gt answers:================================================")
    print(gt_answers[0])
    print("completions:================================================")
    print(completions[0])
    print('-'*100)

    results = []
    for initial_answer, gt_answer, feedback in zip(initial_answers, gt_answers, completions):
        if initial_answer == gt_answer:
            # Correct attempt: the verdict is right only if it says "correct".
            verdict_matches = verify_from_response(feedback)
        else:
            # Wrong attempt: the verdict is right only if it says "incorrect".
            verdict_matches = not verify_from_response(feedback)
        results.append(1 if verdict_matches else 0)

    return np.array(results)

def feedback_format_reward_func(completions, **kwargs):
    """Reward feedback that states an explicit verdict phrase.

    A completion is well-formed when, ignoring case, it contains either
    "there is no error" or "there is an error".

    Args:
        completions: generated feedback strings.
        **kwargs: ignored.

    Returns:
        np.ndarray: 1.0 for each well-formed completion, 0.0 otherwise.
    """
    def is_correct_format(feedback):
        # Both phrases are lower-case because they are matched against the lower-cased text;
        # each needs its own `in` test (a bare string in an `or` is always truthy).
        text = feedback.lower()
        return "there is no error" in text or "there is an error" in text

    # 1 if the feedback is in the correct format
    # 0 otherwise
    return np.array([1.0 if is_correct_format(completion) else 0.0 for completion in completions])

if __name__ == "__main__":
    # Script entry point: parse options, pick the reward function(s), optionally load the
    # vLLM reference model used for second-round inference, then train with GRPO.
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    # Record the options as passed on the command line (before any adjustment below).
    with open(f'{args.output_dir}/args.json', 'w') as f:
        json.dump(vars(args), f)
    set_seed(42)
    # set the reward function here
    if args.reward_to_use == 'vanilla':
        reward_func = vanilla_second_round_correctness_reward_func
    elif args.reward_to_use == 'siv':
        reward_func = [siv_second_round_correctness_reward_func, feedback_format_reward_func]
    elif args.reward_to_use == 'verifier_accuracy':
        reward_func = verifier_accuracy_reward_func
        print("since we are not using the second attempt, we will not use the reference model")
        args.no_reference_model = True
    elif args.reward_to_use == 'both':
        reward_func = [verifier_accuracy_reward_func, siv_second_round_correctness_reward_func]
    else:
        raise ValueError(f"Invalid reward function: {args.reward_to_use}")

    # Every reward except 'verifier_accuracy' runs second-round inference on model_0.
    # Without this check the conflict only surfaces as a NameError inside the first
    # reward call, after all models have been loaded.
    if args.no_reference_model and args.reward_to_use != 'verifier_accuracy':
        raise ValueError(
            f"--no_reference_model cannot be combined with --reward_to_use {args.reward_to_use}: "
            "this reward needs the reference model for second-round inference"
        )

    if not args.no_reference_model:
        # we load the reference model to get the second round of correctness reward
        # NOTE(review): CUDA_VISIBLE_DEVICES is "1" only while the vLLM engine is built and is
        # set back to "0" afterwards — presumably to put the reference model on the second GPU
        # and the trained model on the first. Confirm the switch still takes effect once CUDA
        # has been initialised in this process.
        os.environ["CUDA_VISIBLE_DEVICES"] = "1"
        model_0 = LLM(model=args.base_model_name, dtype='half', max_model_len=2*args.max_seq_length, tensor_parallel_size=1)
        tokenizer_0 = AutoTokenizer.from_pretrained(args.base_model_name)
        gpu_keeper = SimpleGPUKeepAlive(model_0, tokenizer_0, interval=1)
        gpu_keeper.start()
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    num_gpus = torch.cuda.device_count()
    os.environ["WANDB_PROJECT"] = "Feedback"  # name your W&B project
    # initialize model_f
    model_f = AutoModelForCausalLM.from_pretrained(args.base_model_name, torch_dtype=torch.bfloat16)
    tokenizer_f = AutoTokenizer.from_pretrained(args.base_model_name)

    # lora_config = LoraConfig(
    #     r=args.lora_rank,
    #     lora_alpha=args.lora_rank * 2,
    #     lora_dropout=0.05,
    #     bias="none",
    #     target_modules='all-linear',
    #     task_type="CAUSAL_LM",
    # )
    # Prompt templates: feedback.md is used to build the training prompts, refine.md is
    # read by the second-round reward functions through the global `refine_prompt`.
    with open(f'{args.prompt_dir}/feedback.md', 'r') as f:
        feedback_prompt = f.read()
    with open(f'{args.prompt_dir}/refine.md', 'r') as f:
        refine_prompt = f.read()
    # Both tokenizers come from the same model name; this also defines tokenizer_0 when
    # the reference model was not loaded.
    tokenizer_0 = tokenizer_f
    feedback_hf_dataset = make_feedback_dataset(f"{args.output_dir}/star_0/initial_run", feedback_prompt, tokenizer_0)
    # Print a few prompts for a quick visual check.
    for d in feedback_hf_dataset['prompt'][:5]:
        print(d)
        print('-'*100)
    num_cpu = int(os.environ.get('SLURM_CPUS_PER_TASK')) if os.environ.get('SLURM_CPUS_PER_TASK') else 1
    print(f'num_cpu: {num_cpu}')
    print(f'num_gpu: {os.environ.get("SLURM_GPUS_PER_TASK")}')
    grpo_config = GRPOConfig(
        # peft_config=lora_config,
        # eval_strategy='steps',
        # eval_steps=0.1,
        # dataloader_num_workers=1,
        # dataloader_pin_memory=True,
        # Generation is delegated to a vLLM server expected at localhost:8070.
        use_vllm=True,
        vllm_mode="server",
        vllm_server_host="localhost",
        vllm_server_port=8070,
        # vllm_gpu_memory_utilization=0.3,  # Max safe allocation
        generation_kwargs={"stop": ["Problem:", "Solution:", "Analysis:"]},
        optim="paged_adamw_8bit",
        bf16=True,
        fp16=False,
        num_generations=args.num_generations,
        output_dir=f"{args.output_dir}/model_i_checkpoints",
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=4,
        do_eval=False,
        logging_steps=5,
        logging_strategy='steps',
        seed=42,
        num_train_epochs=args.n_epoch,
        logging_first_step=True,
        lr_scheduler_type='constant_with_warmup',
        report_to='wandb',
        learning_rate=args.learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.05,
        #save_strategy='epoch',
        # Fractional value; NOTE(review): presumably read by the trainer as a fraction of the
        # total training steps — confirm, and note that --how_many_checkpoints 0 divides by zero.
        save_steps=float(1/args.how_many_checkpoints),
        max_prompt_length=args.max_seq_length,
        max_completion_length=args.max_new_tokens,
        temperature=args.temperature,
        remove_unused_columns=False,
        beta=args.beta
    )

    trainer = GRPOTrainer(
        model=model_f,
        train_dataset=feedback_hf_dataset,
        args=grpo_config,
        reward_funcs=reward_func
    )

    trainer.train()
    # Drop the trainer and release cached GPU memory once training is done.
    del trainer
    gc.collect()
    torch.cuda.empty_cache()

