import trl
from trl import RewardTrainer
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    TrainingArguments,
)

from utils import set_random_seed
from omegaconf import DictConfig, OmegaConf
import datetime

from utils import prepare_ds, prepare_model_tokenizer
import os
import hydra
import wandb
from data.rm_collator import RMAnthropicDataCollator


@hydra.main(config_path="configs", config_name="train_rm")
def main(config: DictConfig) -> None:
    """Train a reward model with TRL's ``RewardTrainer`` from a Hydra config.

    Seeds the RNGs, opens a Weights & Biases run, builds the dataset and the
    reward model/tokenizer via the project helpers, then trains with
    evaluation, checkpointing and logging once per epoch. Checkpoints are
    written to a timestamped subdirectory of ``config.trainer.output_dir``.
    The wandb run is always closed: normally on success, and with
    ``exit_code=1`` if anything raises (the exception is then re-raised).

    Args:
        config: Hydra-composed configuration (``configs/train_rm``). Keys read
            directly here: ``seed``, ``project_name`` and
            ``trainer.{output_dir, num_train_epochs,
            per_device_train_batch_size, per_device_eval_batch_size,
            learning_rate, weight_decay, gradient_accumulation_steps, bf16}``.
            ``prepare_ds`` / ``prepare_model_tokenizer`` may read more.
            The dataset returned by ``prepare_ds`` must have ``"train"`` and
            ``"test"`` splits.
    """
    set_random_seed(config.seed)

    # No explicit login: wandb picks up credentials from the environment
    # (e.g. WANDB_API_KEY) or a prior `wandb login`. The config, rendered as
    # YAML, is attached to the run as its notes.
    wandb.init(project=config.project_name,
               notes=OmegaConf.to_yaml(config))

    try:
        ds = prepare_ds(config)
        model, tokenizer = prepare_model_tokenizer(config, reward_model=True)

        # Timestamped subdirectory so repeated runs don't overwrite each
        # other's checkpoints.
        dt_now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        output_dir = os.path.join(config.trainer.output_dir, dt_now)

        args = TrainingArguments(
            num_train_epochs=config.trainer.num_train_epochs,
            per_device_train_batch_size=config.trainer.per_device_train_batch_size,
            per_device_eval_batch_size=config.trainer.per_device_eval_batch_size,
            learning_rate=config.trainer.learning_rate,
            weight_decay=config.trainer.weight_decay,
            gradient_accumulation_steps=config.trainer.gradient_accumulation_steps,
            output_dir=output_dir,
            # NOTE(review): newer transformers releases rename this to
            # `eval_strategy`; keep the old name while the pinned version
            # still accepts it.
            evaluation_strategy="epoch",
            save_strategy="epoch",
            logging_strategy="epoch",
            # Keep dataset columns the model's forward() doesn't accept;
            # presumably the custom collator below needs the raw fields.
            remove_unused_columns=False,
            bf16=config.trainer.bf16,
        )

        trainer = RewardTrainer(
            model=model,
            args=args,
            # NOTE(review): newer TRL releases take `processing_class=`
            # instead of `tokenizer=`; confirm against the pinned version.
            tokenizer=tokenizer,
            train_dataset=ds["train"],
            eval_dataset=ds["test"],
            data_collator=RMAnthropicDataCollator(tokenizer)
        )

        trainer.train()
    except BaseException:
        # Mark the run as failed (covers Ctrl-C too) instead of leaving it to
        # exit hooks that Hydra's own error handling runs in front of; then
        # let the error propagate unchanged.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()




if __name__ == "__main__":
    # Called with no arguments: the @hydra.main decorator composes `config`
    # from configs/train_rm plus any command-line overrides.
    main()