import argparse
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers.trainer import get_scheduler
import torch.nn as nn
import torch.optim as optim
from openrlhf.utils import blending_datasets, MockStrategy
from openrlhf.datasets import ProbDataset
from openrlhf.trainer import DPOPTrainer, DPOPMTrainer, RewardModelPTrainer
from openrlhf.models import get_llm_for_sequence_regression
from torch.utils.data import DataLoader
import pandas as pd
from openrlhf.models import Actor
from datetime import datetime
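# Evaluation-only script: parses CLI options, loads a reward-model checkpoint and evaluates it on a preference dataset.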



# Device string used by this script; it is hard-wired to CUDA.
device = "cuda"

# Command-line interface. Every option has a default, so the script can be
# launched without arguments.
parser = argparse.ArgumentParser()

# Checkpointing / logging options
parser.add_argument("--save_path", type=str, default="./ckpt")
parser.add_argument("--save_steps", type=int, default=-1)
parser.add_argument("--logging_steps", type=int, default=1)
parser.add_argument("--eval_steps", type=int, default=-1)
parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_rm")
parser.add_argument("--max_ckpt_num", type=int, default=3)
# argparse only applies `type` to string defaults, so a bare 1e8 would reach
# the consumer as a float; spell the default as a real int.
parser.add_argument("--max_ckpt_mem", type=int, default=int(1e8))
parser.add_argument("--load_checkpoint", action="store_true", default=False)

# DeepSpeed
# Options are declared as (flag, add_argument kwargs) pairs and registered in
# this exact order, which is also the order shown by --help.
_SWITCH = {"action": "store_true", "default": False}  # spec shared by on/off flags
_DEEPSPEED_OPTIONS = [
    ("--max_norm", {"type": float, "default": 1.0, "help": "Gradient clipping"}),
    ("--gradient_checkpointing", _SWITCH),
    ("--seed", {"type": int, "default": 42}),
    ("--local_rank", {"type": int, "default": -1, "help": "local_rank for deepspeed"}),
    ("--zero_stage", {"type": int, "default": 2, "help": "DeepSpeed ZeRO stage"}),
    ("--bf16", {**_SWITCH, "help": "Enable bfloat16"}),
    ("--zpg", {"type": int, "default": 1, "help": "ZeRO++ max partition size"}),
    ("--adam_offload", {**_SWITCH, "help": "Offload Adam Optimizer"}),
    ("--flash_attn", {**_SWITCH, "help": "Enable FlashAttention2"}),
    ("--grad_accum_dtype", {"type": str, "default": None, "help": "Adam grad accum data type"}),
    ("--overlap_comm", _SWITCH),
    ("--gradient_checkpointing_use_reentrant", _SWITCH),
    ("--disable_fast_tokenizer", _SWITCH),
]
for _flag, _spec in _DEEPSPEED_OPTIONS:
    parser.add_argument(_flag, **_spec)

# Models, context parallelism and LoRA
# Each entry is (flag, add_argument kwargs); registration order matches the
# original one-call-per-option layout.
_MODEL_OPTIONS = [
    # Models
    ("--pretrain", {"type": str, "default": None}),
    ("--value_head_prefix", {"type": str, "default": "score"}),
    # Context Parallel
    ("--ring_attn_size", {"type": int, "default": 1, "help": "Ring attention group size"}),
    (
        "--ring_head_stride",
        {
            "type": int,
            "default": 1,
            "help": (
                "the number of heads to do ring attention each time. "
                "It should be a divisor of the number of heads. "
                "A larger value may results in faster training but will consume more memory."
            ),
        },
    ),
    # LoRA
    ("--load_in_4bit", {"action": "store_true", "default": False}),
    ("--lora_rank", {"type": int, "default": 0}),
    ("--lora_alpha", {"type": int, "default": 16}),
    ("--lora_dropout", {"type": float, "default": 0}),
    ("--target_modules", {"type": str, "nargs": "*", "default": "all-linear"}),
]
for _flag, _spec in _MODEL_OPTIONS:
    parser.add_argument(_flag, **_spec)

# RM training (plus sample packing)
# Declared as data and registered in order; the specs are identical to the
# former one-call-per-option form.
_ON_OFF = {"action": "store_true", "default": False}  # spec shared by boolean flags
_RM_TRAINING_OPTIONS = [
    ("--max_epochs", {"type": int, "default": 1}),
    ("--aux_loss_coef", {"type": float, "default": 0, "help": "MoE balancing loss"}),
    ("--compute_fp32_loss", _ON_OFF),
    ("--margin_loss", _ON_OFF),
    ("--learning_rate", {"type": float, "default": 9e-6}),
    ("--lr_warmup_ratio", {"type": float, "default": 0.03}),
    ("--micro_train_batch_size", {"type": int, "default": 1}),
    ("--train_batch_size", {"type": int, "default": 128, "help": "Global training batch size"}),
    ("--loss", {"type": str, "default": "sigmoid"}),
    ("--l2", {"type": float, "default": 0.0, "help": "weight decay loss"}),
    (
        "--adam_betas",
        {"type": float, "nargs": 2, "default": (0.9, 0.95), "help": "Betas for Adam optimizer"},
    ),
    # packing samples using Flash Attention2
    ("--packing_samples", _ON_OFF),
]
for _flag, _spec in _RM_TRAINING_OPTIONS:
    parser.add_argument(_flag, **_spec)

# Custom dataset
# Which dataset(s) to load, which record fields hold the chosen/rejected
# responses, and how samples are templated and truncated.
parser.add_argument("--dataset", type=str, default='JSON_Preference')
parser.add_argument("--dataset_probs", type=str, default="1.0", help="sampling probs for datasets")
parser.add_argument("--prompt_key", type=str, default=None)
parser.add_argument("--chosen_key", type=str, default="JSON Preference 1")
parser.add_argument("--rejected_key", type=str, default="JSON Preference 2")
parser.add_argument("--input_template", type=str, default=None)
parser.add_argument(
    "--apply_chat_template", action="store_true", default=False, help="Use HF tokenizer chat template"
)
parser.add_argument("--tokenizer_chat_template", type=str, default=None)
parser.add_argument("--train_split", type=str, default="train", help="train split of the HF dataset")
parser.add_argument("--eval_split", type=str, default="test", help="test split of the dataset")
# argparse does not pass non-string defaults through `type`, so a bare 1e8
# would stay a float. The value is later compared with dataset lengths and
# handed to range(), which rejects floats, so the default must be an int.
parser.add_argument("--max_samples", type=int, default=int(1e8), help="Max number of samples")
parser.add_argument("--max_len", type=int, default=512)

# wandb parameters
# The default run name embeds the launch time (month, day, "T", hour:minute).
_launch_stamp = datetime.now().strftime("%m%dT%H:%M")
_TRACKING_OPTIONS = [
    ("--use_wandb", {"type": str, "default": None}),
    ("--wandb_org", {"type": str, "default": None}),
    ("--wandb_group", {"type": str, "default": None}),
    ("--wandb_project", {"type": str, "default": "openrlhf_train_rm"}),
    ("--wandb_run_name", {"type": str, "default": f"rm_{_launch_stamp}"}),
    # TensorBoard parameters
    ("--use_tensorboard", {"type": str, "default": None, "help": "TensorBoard logging path"}),
]
for _flag, _spec in _TRACKING_OPTIONS:
    parser.add_argument(_flag, **_spec)

# Parse the process command line into the namespace used by the rest of the script.
args = parser.parse_args()


# Reward-model checkpoint to evaluate. --pretrain overrides the built-in
# default path; without it the behaviour is unchanged.
model_name = args.pretrain or '/data/checkpoint/gemma3-1b-rm'
# Load the tokenizer of the Gemma-3 1B base model. The fast tokenizer is used
# unless --disable_fast_tokenizer is given.
tokenizer = AutoTokenizer.from_pretrained(
    'google/gemma-3-1b-pt',
    use_fast=not args.disable_fast_tokenizer,
    trust_remote_code=True,
)
# Use eos_token as pad_token and pad on the right-hand side.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
# Chat template: wraps every message as
#   <|start_header_id|>{role}<|end_header_id|>\n\n{trimmed content}<|eot_id|>
# prefixes the first message with bos_token, and appends an open assistant
# header when add_generation_prompt is set.
# NOTE(review): these markers look like the Llama-3 chat format rather than
# Gemma's own; confirm this matches how the checkpoint was trained.
template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"""
tokenizer.chat_template = template

# Stand-in strategy object built from the parsed CLI options; it is handed to
# the dataset helpers and to the trainer below.
strategy = MockStrategy(args)

# Load the train and eval splits of the configured dataset(s).
train_data, eval_data = blending_datasets(
    args.dataset,
    args.dataset_probs,
    strategy,
    args.seed,
    max_count=args.max_samples,
    stopping_strategy="all_exhausted",
    train_split=args.train_split,
    eval_split=args.eval_split,
)
# Cap each split at --max_samples rows (train first, then eval).
train_data, eval_data = (
    split.select(range(min(args.max_samples, len(split))))
    for split in (train_data, eval_data)
)
# Training dataset is intentionally disabled: this script only evaluates.
# train_dataset = ProbDataset(
#     train_data,
#     tokenizer,
#     args.max_len,
#     strategy,
#     input_template=args.input_template,
#     is_dpo=True,
# )
eval_dataset = ProbDataset(
    eval_data, tokenizer, args.max_len, strategy,
    input_template=args.input_template, is_dpo=True,
)
    



# Load the reward model to be evaluated from `model_name` and move it to the
# target device.
model = get_llm_for_sequence_regression(
    model_name,
    "reward",
    use_flash_attention_2=args.flash_attn,
    bf16=args.bf16,
    load_in_4bit=args.load_in_4bit,
    lora_rank=args.lora_rank,
    lora_alpha=args.lora_alpha,
    target_modules=args.target_modules,
    lora_dropout=args.lora_dropout,
    # Keep the value head stored in the checkpoint. In upstream OpenRLHF,
    # init_value_head=True re-draws the head weights at random after loading,
    # which is meant for the start of RM training and would make an
    # evaluation of a trained reward model meaningless.
    init_value_head=False,
    value_head_prefix=args.value_head_prefix,
    packing_samples=args.packing_samples,
).to(device)




# ref_model = Actor(
#     'google/gemma-3-1b-pt',  # Replace with the actual LLaMA model path
#     use_flash_attention_2=args.flash_attn,
#     bf16=args.bf16,
#     load_in_4bit=args.load_in_4bit,                    # Use LoRA fine-tuning for efficiency
# ).to(device)

# Build the evaluation DataLoader (the training loader is disabled in this script).
# train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=train_dataset.collate_fn)
# Batches are drawn in dataset order and assembled by the dataset's own
# collate function; the batch size reuses --train_batch_size.
eval_dataloader = DataLoader(
    dataset=eval_dataset,
    batch_size=args.train_batch_size,
    shuffle=False,
    collate_fn=eval_dataset.collate_fn,
)

# Define the optimizer
# optimizer = torch.optim.AdamW(actor_model.parameters(), lr=1e-4)

# scheduler = get_scheduler(
#         "cosine_with_min_lr",
#         optimizer,
#         num_warmup_steps=100,
#         num_training_steps=10000,
#         scheduler_specific_kwargs={"min_lr": 1e-5},
#     )




# Configure the trainer for an evaluation-only run: no training dataloader and
# no LR scheduler are supplied.
# NOTE(review): `optim` is the `torch.optim` module imported at the top of the
# file, not an optimizer instance - presumably unused by evaluate(); confirm.
trainer_kwargs = dict(
    model=model,
    strategy=strategy,
    optim=optim,
    tokenizer=tokenizer,
    train_dataloader=None,
    eval_dataloader=eval_dataloader,
    scheduler=None,
    max_norm=args.max_norm,
    max_epochs=args.max_epochs,
    loss=args.loss,
)
trainer = RewardModelPTrainer(**trainer_kwargs)

# Training path, kept for reference but disabled:
# consumed_samples = 0
# num_update_steps_per_epoch = len(train_dataset) // args.train_batch_size
# trainer.fit(args, consumed_samples, num_update_steps_per_epoch)
# strategy.save_model(actor_model, output_dir="./checkpoint/llama3-1b-pm")

# Run a single evaluation pass over the eval split.
trainer.evaluate(eval_dataloader)


