import argparse
import gc
import json
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Third-party imports
# from unsloth import FastLanguageModel, is_bfloat16_supported
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from trl import GRPOConfig, GRPOTrainer
import numpy as np
# Local imports
from src.data_prep import make_baseline_dataset
from src.helper import extract_math_answer

# Process-wide environment flags, applied once at import time (after the
# imports above have already run).
# NOTE(review): presumably these disable torch.compile / Dynamo and stop W&B
# from uploading model artifacts -- confirm against the libraries' docs.
os.environ.update({
    "TORCH_COMPILE_DISABLE": "1",
    "PYTORCH_DISABLE_DYNAMO": "1",
    "WANDB_LOG_MODEL": "false",
})

def parse_args():
    """Parse the command-line options of the GRPO training script.

    Every option is optional; an omitted flag falls back to the default
    listed in the table below.

    Returns:
        argparse.Namespace: one attribute per flag (without the leading
        dashes), holding the parsed value or its default.
    """
    # (flag, value type, default) -- registered in this order.
    option_table = (
        ('--stop', int, 9999999),                   # not referenced in the visible script
        ('--dataset_name', str, 'GSM'),             # dataset name
        ('--base_model_name', str, None),           # model name for training
        ('--output_dir', str, 'output'),            # output directory
        ('--n_epoch', int, 5),                      # number of epochs
        ('--learning_rate', float, 1e-5),           # learning rate
        ('--batch_size', int, 6),                   # per-device batch size
        ('--gradient_accumulation_steps', int, 4),  # gradient accumulation steps
        ('--num_generations', int, 8),              # number of generations
        ('--temperature', float, 1.0),              # sampling temperature
        ('--max_new_tokens', int, 2000),            # max completion length
        ('--max_seq_length', int, 2048),            # max prompt length
        ('--how_many_checkpoints', int, 10),        # how many checkpoints to save
        ('--checkpoint', str, None),                # checkpoint to resume from
        ('--beta', float, 0.0),                     # forwarded as GRPOConfig(beta=...)
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in option_table:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli.parse_args()

def correctness_reward(completions, ground_truth, **kwargs):
    """Score each completion 1.0 or 0.0 by comparing extracted answers.

    Args:
        completions: iterable of model responses.
        ground_truth: iterable of reference solutions, paired positionally
            with ``completions`` (pairing stops at the shorter of the two).
        **kwargs: accepted and ignored, so callers may pass extra fields.

    Returns:
        numpy.ndarray: one entry per pair -- 1.0 when ``extract_math_answer``
        yields equal values for the response and the reference, else 0.0.
    """
    scores = []
    for response, reference in zip(completions, ground_truth):
        answers_match = extract_math_answer(response) == extract_math_answer(reference)
        scores.append(1.0 if answers_match else 0.0)
    return np.array(scores)

if __name__ == "__main__":
    # Script entry point: load the base model, build the baseline dataset from
    # the dataset-specific inference prompt, and train with TRL's GRPOTrainer
    # using `correctness_reward` as the only reward function.
    args = parse_args()

    # Fail fast with a readable message instead of crashing later inside
    # from_pretrained(None) or with a bare ZeroDivisionError in the config.
    if args.base_model_name is None:
        raise ValueError("--base_model_name is required")
    if args.how_many_checkpoints < 1:
        raise ValueError("--how_many_checkpoints must be >= 1")

    # Persist the exact run configuration next to the checkpoints.
    os.makedirs(args.output_dir, exist_ok=True)
    with open(f'{args.output_dir}/args.json', 'w') as f:
        json.dump(vars(args), f)

    set_seed(42)
    os.environ["WANDB_PROJECT"] = "Feedback"  # name your W&B project

    # initialize model_f
    model_f = AutoModelForCausalLM.from_pretrained(args.base_model_name, torch_dtype=torch.bfloat16)
    # NOTE(review): this tokenizer is loaded but never handed to GRPOTrainer
    # below; the load is kept from the original script -- confirm whether it
    # should be passed as `processing_class`.
    tokenizer_f = AutoTokenizer.from_pretrained(args.base_model_name)

    # Dataset names containing "SUDOKU_" share one prompt file keyed by the
    # part before the first underscore; every other dataset uses its full name.
    if "SUDOKU_" in args.dataset_name:
        prompt_key = args.dataset_name.split("_")[0]
    else:
        prompt_key = args.dataset_name
    # The path is relative to the current working directory.
    with open(f'prompts/base_model/inference_{prompt_key}.md', 'r', encoding='utf-8') as f:
        inference_prompt = f.read()

    baseline_hf_dataset = make_baseline_dataset(inference_prompt, args.dataset_name)
    # Show the first few prompts for a quick visual sanity check.
    print("baseline_hf_dataset:================================================")
    for d in baseline_hf_dataset['prompt'][:5]:
        print(d)
        print('-'*100)

    # Informational only: the value is printed but not used further.
    # Unset or empty SLURM_CPUS_PER_TASK falls back to 1.
    num_cpu = int(os.environ.get('SLURM_CPUS_PER_TASK') or 1)
    print(f'num_cpu: {num_cpu}')

    grpo_config = GRPOConfig(
        use_vllm=True,
        vllm_mode="colocate",
        vllm_gpu_memory_utilization=0.3,
        generation_kwargs={"stop": ["Problem", "Solution", "Analysis", "Instruction"]},
        optim="paged_adamw_8bit",
        bf16=True,
        fp16=False,
        num_generations=args.num_generations,
        output_dir=f"{args.output_dir}/model_i_checkpoints",
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        do_eval=False,
        logging_steps=5,
        logging_strategy='steps',
        seed=42,
        num_train_epochs=args.n_epoch,
        logging_first_step=True,
        lr_scheduler_type='constant_with_warmup',
        report_to='wandb',
        learning_rate=args.learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.05,
        # A value below 1 is meant as a fraction of the total training steps,
        # giving roughly `how_many_checkpoints` evenly spaced saves.
        # NOTE(review): with how_many_checkpoints == 1 this is 1.0, which
        # transformers appears to treat as "every step" rather than a ratio
        # -- confirm before relying on it.
        save_steps=1 / args.how_many_checkpoints,
        max_prompt_length=args.max_seq_length,
        max_completion_length=args.max_new_tokens,
        temperature=args.temperature,
        remove_unused_columns=False,
        beta=args.beta,
    )

    trainer = GRPOTrainer(
        model=model_f,
        train_dataset=baseline_hf_dataset,
        args=grpo_config,
        reward_funcs=correctness_reward,
    )

    # `--checkpoint` defaults to None, which starts training from scratch.
    trainer.train(resume_from_checkpoint=args.checkpoint)

    # Drop the trainer and release cached GPU memory before the process exits.
    del trainer
    gc.collect()
    torch.cuda.empty_cache()

