# Config for single device RLHF full finetuning using PPO in ppo_full_finetune_single_device.py
# using a Mistral 7B model.
#
# This config has been tested on an A100 80GB.
# This config uses hyperparameters based on a small set of experiments and information
# available from existing implementations.
#
# This config assumes that you've run the following commands before launching
# this run:
#   tune download weqweasdas/RM-Mistral-7B --output-dir /tmp/RM-Mistral-7B/
#   tune download mistralai/Mistral-7B-Instruct-v0.2 --output-dir /tmp/Mistral-7B-Instruct-v0.2/ --ignore-patterns "*.safetensors" --hf-token <HF_TOKEN>
#
# You'll also need to ensure that {output_dir} exists beforehand, as checkpoints for policy and value models are saved in sub-folders.
# The default config uses an optimizer from bitsandbytes. If you do not have it installed,
# you can install it with
#   pip install bitsandbytes
#
# To launch on a single device, run the following command from root:
#   tune run ppo_full_finetune_single_device --config mistral/7B_full_ppo_low_memory
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
#   tune run ppo_full_finetune_single_device --config mistral/7B_full_ppo_low_memory checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#

# Root directory for this run's outputs; the checkpointer, metric logger and
# profiler sections below interpolate it via ${output_dir}.
# /tmp may be deleted by your system. Change it to your preference.
output_dir: /tmp/torchtune/mistral_7B/full_ppo_low_memory

# Tokenizer
# Built from the tokenizer.model file inside the Mistral-7B-Instruct-v0.2
# download directory.
tokenizer:
  _component_: torchtune.models.mistral.mistral_tokenizer
  max_seq_len: 512
  path: /tmp/Mistral-7B-Instruct-v0.2/tokenizer.model

# Dataset
# Text-completion dataset that reads the `prompt` column of the train split;
# add_eos is switched off.
dataset:
  _component_: torchtune.datasets.text_completion_dataset
  source: trl-internal-testing/sentiment-trl-style
  column: prompt
  split: train
  add_eos: False

# Policy model
policy_model:
  _component_: torchtune.models.mistral.mistral_7b

# Reward / value model
# The mistral classifier is built manually here because the reward model
# checkpoint has a larger vocabulary size (due to an added padding token).
# NOTE(review): the size arguments below are presumably forwarded to the builder
# named by `base_model_path` - confirm that builder accepts them.
reward_and_value_model:
  _component_: torchtune.modules.classifier_model
  base_model_path: torchtune.models.mistral.mistral_7b
  num_classes: 1
  # vocabulary and sequence length
  vocab_size: 32001
  max_seq_len: 32768
  # transformer dimensions
  embed_dim: 4096
  intermediate_dim: 14336
  num_layers: 32
  # attention
  num_heads: 32
  num_kv_heads: 8
  attn_dropout: 0.0
  # normalisation
  norm_eps: 1.0e-05

# checkpointer for the policy model - update this if resuming from checkpoint
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Mistral-7B-Instruct-v0.2/
  checkpoint_files:
    [
      "pytorch_model-00001-of-00003.bin",
      "pytorch_model-00002-of-00003.bin",
      "pytorch_model-00003-of-00003.bin",
    ]
  # this is the only place where you should update `recipe_checkpoint` if resuming training
  recipe_checkpoint: null
  # policy checkpoints are written to their own sub-folder of the run directory,
  # the same one `ref_policy_checkpointer` uses, so they stay separate from the
  # `value` sub-folder
  output_dir: ${output_dir}/policy
  model_type: MISTRAL

# Reference policy checkpointer. At the start of training this should be set up
# identically to the policy model checkpointer. Keep `checkpoint_files` pointing
# at the original policy weights, even if resuming training.
ref_policy_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  model_type: MISTRAL
  checkpoint_dir: /tmp/Mistral-7B-Instruct-v0.2/
  checkpoint_files:
    - "pytorch_model-00001-of-00003.bin"
    - "pytorch_model-00002-of-00003.bin"
    - "pytorch_model-00003-of-00003.bin"
  output_dir: ${output_dir}/policy

# Value model checkpointer - update `checkpoint_files` if resuming from checkpoint.
# The value model is identical to the reward model, so it is helpful to
# initialise it from the trained reward model weights.
value_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  model_type: REWARD
  checkpoint_dir: /tmp/RM-Mistral-7B/
  checkpoint_files:
    - "model-00001-of-00003.safetensors"
    - "model-00002-of-00003.safetensors"
    - "model-00003-of-00003.safetensors"
  output_dir: ${output_dir}/value

# Reward model checkpointer. Keep `checkpoint_files` pointing at the original
# reward model weights, even if resuming training.
reward_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  model_type: REWARD
  checkpoint_dir: /tmp/RM-Mistral-7B/
  checkpoint_files:
    - "model-00001-of-00003.safetensors"
    - "model-00002-of-00003.safetensors"
    - "model-00003-of-00003.safetensors"
  output_dir: ${output_dir}/value

# Run control
seed: null
shuffle: True
resume_from_checkpoint: False

# Training env
device: cuda

# Training arguments
num_steps: 10000
batch_size: 128
ppo_batch_size: 128
ppo_epochs: 1
# Use to increase effective batch size; `optimizer_in_bwd: True` below requires
# this to stay at 1.
gradient_accumulation_steps: 1

# Memory management and performance
compile: True  # torch.compile the model + loss, True increases speed + decreases memory
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW
  # Written with an explicit fractional part (the same form as `norm_eps` above):
  # a bare `3e-6` is read as a string by YAML 1.1 loaders. The value is unchanged.
  lr: 3.0e-06
optimizer_in_bwd: True  # True saves memory. Requires gradient_accumulation_steps=1
log_peak_memory_stats: True
enable_activation_checkpointing: True  # True reduces memory
enable_kv_cache: True

# Reduced precision
dtype: bf16

# Generation settings
# batch size for forward pass during generation
forward_batch_size: 128
temperature: 0.7
top_k: null
max_generated_tokens: 58

# Reward penalties
# scalar penalty to apply when penalising
reward_penalty: -3
# generations shorter than `min_response_length` are penalised
min_response_length: 18
# when True, generations without a stop token are penalised
penalise_no_eos: True

# tokens to consider as "end of sequence" tokens
stop_token_ids:
  - 2  # eos_id
  - 28723  # mistral "." token
whiten_rewards: False

# GAE hyperparameters
lmbda: 0.95
gamma: 1

# PPO hyperparameters
kl_coeff: 0.01
loss:
  _component_: torchtune.rlhf.loss.PPOLoss
  value_clip_range: 0.2
  value_coeff: 0.1
  epsilon: 0.2

# Logging
log_every_n_steps: 1

# Disk-based metric logger writing under the run's output directory.
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs

# Profiler (disabled by default via `enabled: False`)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cuda: True
  cpu: True

  # Trace options passed to `torch.profiler.profile`
  with_flops: False
  with_stack: False
  record_shapes: False
  profile_memory: True

  # `torch.profiler.schedule` options, with the schedule argument each maps to:
  #   wait_steps   -> wait
  #   warmup_steps -> warmup
  #   active_steps -> active
  #   num_cycles   -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 3
  num_cycles: 1
