---
# Model arguments
#
# Base instruct checkpoints (starting from scratch):
# model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
# model_name_or_path: Qwen/Qwen2.5-0.5B-Instruct
# model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct

# GRPO checkpoints used to start "best move from legal" training.
# The grpo_v2_expt_lr_1e-6/ checkpoints appear to come from legal-move runs.
# NOTE(review): the grpo_v2_expt_lr_1e-6_best_move/ checkpoint-1800 entry is
# itself a best-move checkpoint; confirm it belongs in this group.
# model_name_or_path: /mnt/shared/data/stlm-logic/grpo_v2_expt_lr_1e-6/Qwen_Qwen2.5-1.5B-Instruct_nl_canconical-symmetry-grouping/checkpoint-750
# model_name_or_path: /mnt/shared/data/stlm-logic/grpo_v2_expt_lr_1e-6/Qwen_Qwen2.5-1.5B-Instruct_nl_random_80_10_10/checkpoint-750
# model_name_or_path: /mnt/shared/data/stlm-logic/grpo_v2_expt_lr_1e-6_best_move/Qwen_Qwen2.5-0.5B-Instruct_nl_canconical-symmetry-grouping_best_move/checkpoint-1800
# model_name_or_path: /mnt/shared/data/stlm-logic/grpo_v2_expt_lr_1e-6/Qwen_Qwen2.5-0.5B-Instruct_nl_random_80_10_10/checkpoint-600

# Active starting point: Llama-3.2-1B, canonical-symmetry-grouping split.
model_name_or_path: '/mnt/shared/data/stlm-logic/grpo_v2_expt_lr_1e-6/meta-llama_Llama-3.2-1B-Instruct_nl_canconical-symmetry-grouping/checkpoint-600'
# Same model, random_80_10_10 split:
# model_name_or_path: /mnt/shared/data/stlm-logic/grpo_v2_expt_lr_1e-6/meta-llama_Llama-3.2-1B-Instruct_nl_random_80_10_10/checkpoint-450

# Qwen2 (not Qwen2.5) checkpoints: SAE Lens does not support Qwen2.5, so
# models meant for SAE analysis have to be trained from Qwen2.
# model_name_or_path: Qwen/Qwen2-1.5B-Instruct
# model_name_or_path: /mnt/shared/data/stlm-logic/grpo_v2_expt_lr_1e-6_best_move/Qwen_Qwen2-1.5B-Instruct_nl_canconical-symmetry-grouping_best_move/checkpoint-300

# Llama checkpoints for the "special" representation mode (special tokens).
# model_name_or_path: /mnt/shared/data/stlm-logic/models/meta-llama_Llama-3.2-1B-Instruct_special_updated
# model_name_or_path: /mnt/shared/data/stlm-logic/grpo_v2_expt_updated/meta-llama_Llama-3.2-1B-Instruct_special_canconical-symmetry-grouping_updated

# NOTE(review): this path uses /mnt/data/data/, while every other path uses
# /mnt/shared/data/. It is likely stale; verify before re-enabling.
# model_name_or_path: /mnt/data/data/stlm-logic/grpo_v2_expt/Qwen_Qwen2.5-1.5B-Instruct_special_random_80_10_10/checkpoint-200

# Loading options
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2

# Data training arguments
#
# Stock open-r1 dataset, kept for reference:
# dataset_name: open-r1/OpenR1-Math-220k
# dataset_prompt_column: problem

# Tic-tac-toe dataset settings.
dataset_name: 'tictactoe'
# Literal block with "-" chomping: newlines inside are kept, and no trailing
# newline is appended.
system_prompt: |-
  You are a helpful assistant skilled at reasoning for tic tac toe. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>
  ...
  </think>
  <answer>
  ...
  </answer>

# Task: 'legal_move' or 'best_move'.
# experiment_mode: 'legal_move'
experiment_mode: 'best_move'

# Board representation: 'special' (special tokens) or 'nl' (natural language).
# representation_mode: 'special'
representation_mode: 'nl'

# Data split. NOTE(review): the "canconical" spelling presumably has to match
# the existing dataset and checkpoint directory names, so do not correct it
# here on its own.
# dataset_type: 'random_80_10_10'
dataset_type: 'canconical-symmetry-grouping'

# Optional column names forwarded to generate_prompt_grpo:
# past_data_column: past_timeline
# qa_pairs_column: qa_pairs

# LLM-judge settings for GRPO (currently disabled)
# judge_model_name: unsloth/gemma-3-27b-it-bnb-4bit
# judge_model_url: https://fcbb3a58f730.ngrok.app/v1
# judge_api_key: "JUDGE_API_KEY"  # or a hard-coded key / vault reference (never commit a real key)



# GRPO trainer config

# Precision and memory
bf16: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false

# Optimisation
# learning_rate: 2.0e-05
learning_rate: 1.0e-06
lr_scheduler_type: cosine
gradient_accumulation_steps: 8
num_train_epochs: 5
max_steps: -1  # not positive, so training length follows num_train_epochs

# Rollout generation
use_vllm: true
num_generations: 8  # completions sampled per prompt
max_prompt_length: 8192  # TODO: revisit this limit
max_completion_length: 2048  # TODO: revisit this limit

# Evaluation and logging
do_eval: false
log_completions: true
log_level: debug
logging_strategy: steps
logging_steps: 1
logging_first_step: true

# Hub settings
# NOTE(review): this name says Qwen2.5-1.5B, but the active model is
# Llama-3.2-1B. It only matters if pushing to the Hub is enabled.
hub_model_id: Qwen2.5-1.5B-Open-R1-GRPO
hub_strategy: every_save
# Output location
# TODO: update output dir.
# This mimics the minimal feasible dataset that the zero-shot eval used. It was
# used to test the GRPO script while the real training dataset was being
# prepared. NOTE(review): possibly stale; confirm.
# NOTE(review): if model_name_or_path points at a legal-move checkpoint, the
# "_best_move_from_legal" directory below may be the intended target. With
# overwrite_output_dir: true, an earlier best-move run in the active directory
# may be overwritten.
output_dir: /mnt/shared/data/stlm-logic/grpo_v2_expt_lr_1e-6_best_move
# output_dir: /mnt/shared/data/stlm-logic/grpo_v2_expt_lr_1e-6_best_move_from_legal
overwrite_output_dir: true

# Batching
per_device_train_batch_size: 8
per_device_eval_batch_size: 8  # only used when evaluation is enabled

# Publishing and tracking
push_to_hub: false
report_to: [wandb]
# Reward functions and their weights, paired by position.
# TODO: update rewards.
# NOTE(review): experiment_mode is set to best_move, but the first reward is
# legal_move. Confirm this is intended for best-move training.
reward_funcs: [legal_move, format, tag_count]
reward_weights: [1.0, 1.0, 1.0]
# Checkpointing
# TODO: save every 0.2 epochs. A float save_steps below 1 is read as a fraction
# of total training steps. With num_train_epochs: 5, that would be 0.04.
save_strategy: steps
save_steps: 150
save_total_limit: 30

# Reproducibility and warmup
seed: 42
warmup_ratio: 0.1

# vLLM server connection
vllm_server_port: 8006
# vllm_port: 8001
# vllm_server_host: 0.0.0.0
# vllm_server_timeout: 120.0
# If re-enabled, write `null`: YAML reads a bare `None` as the string "None".
# vllm_guided_decoding_regex: None
# YAML does not expand ${...}; presumably a launcher substitutes it. Confirm.
# run_name: "${MODEL_MARK}"