import torch
class AWDPOConfig:
    """Plain container of training/generation settings read from keyword arguments.

    Every attribute is looked up in ``kwargs`` under its own name and falls
    back to the default shown in ``__init__`` when the key is absent.
    Keyword arguments that do not correspond to a known setting are ignored.
    """

    def __init__(self, **kwargs) -> None:
        """Populate all settings from ``kwargs``, applying defaults.

        Notable defaults:
            * ``device`` is ``"cuda"`` when ``torch.cuda.is_available()``
              returns true, otherwise ``"cpu"``.
            * ``vllm_dtype`` is ``"bfloat16"`` when ``bf16`` is truthy,
              otherwise ``"float16"``.
            * ``vllm_server_timeout`` is read from ``vllm_server_timeout``;
              the legacy key ``connection_timeout`` is still accepted as a
              fallback, and the final default is ``60.0``.
        """
        # Run identification and output location.
        self.output_dir = kwargs.get("output_dir", "outputs")
        self.run_name = kwargs.get("run_name", "custom_grpo")

        # Optimizer / schedule settings.
        self.learning_rate = kwargs.get("learning_rate", 1e-5)
        self.weight_decay = kwargs.get("weight_decay", 0.01)
        self.warmup_steps = kwargs.get("warmup_steps", 50)

        # Generation and sequence-length settings.
        self.num_generations = kwargs.get("num_generations", 1)
        self.max_prompt_length = kwargs.get("max_prompt_length", 256)
        self.max_completion_length = kwargs.get("max_completion_length", 256)

        # Training-loop settings.
        self.num_train_epochs = kwargs.get("num_train_epochs", 1)
        self.gradient_accumulation_steps = kwargs.get("gradient_accumulation_steps", 1)
        self.logging_steps = kwargs.get("logging_steps", 1)
        self.save_steps = kwargs.get("save_steps", 50)
        self.max_steps = kwargs.get("max_steps", 1000)
        # The availability probe runs on every construction, even when the
        # caller supplies ``device`` explicitly (the default is built eagerly).
        self.device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        self.temperature = kwargs.get("temperature", 0.2)
        self.num_generated_samples_to_view = kwargs.get("num_generated_samples_to_view", 10)
        self.bf16 = kwargs.get("bf16", True)
        self.per_device_train_batch_size = kwargs.get("per_device_train_batch_size", 4)
        self.use_flash_attn_2 = kwargs.get("use_flash_attn_2", False)

        # vLLM-related settings.
        self.use_vllm = kwargs.get("use_vllm", False)
        self.vllm_device = kwargs.get("vllm_device", "cuda:0")
        self.vllm_gpu_memory_utilization = kwargs.get("vllm_gpu_memory_utilization", 0.2)
        # Must come after ``self.bf16`` is set: the default depends on it.
        self.vllm_dtype = kwargs.get("vllm_dtype", "bfloat16" if self.bf16 else "float16")
        self.vllm_max_model_len = kwargs.get("vllm_max_model_len", 512)

        # Evaluation settings.
        self.eval_no_shot = kwargs.get("eval_no_shot", False)
        self.eval_interval = kwargs.get("eval_interval", 100)
        self.eval_temperature = kwargs.get("eval_temperature", 0.7)

        # Sampling settings.
        self.repetition_penalty = kwargs.get("repetition_penalty", 1.2)
        self.top_p = kwargs.get("top_p", 0.9)
        self.top_k = kwargs.get("top_k", 50)
        self.min_p = kwargs.get("min_p", 0.0)
        self.guided_decoding_regex = kwargs.get("guided_decoding_regex", None)

        # vLLM server connection settings.
        self.vllm_server_host = kwargs.get("vllm_server_host", "0.0.0.0")
        self.vllm_server_port = kwargs.get("vllm_server_port", 8000)
        # Previously this attribute was only settable via "connection_timeout",
        # so passing ``vllm_server_timeout=...`` was silently dropped. Prefer
        # the attribute's own name and keep the old key as a fallback.
        self.vllm_server_timeout = kwargs.get(
            "vllm_server_timeout", kwargs.get("connection_timeout", 60.0)
        )

        # Reference model, LoRA and remaining feature switches.
        self.use_reference_model = kwargs.get("use_reference_model", False)
        self.use_lora = kwargs.get("use_lora", False)
        self.lora_rank = kwargs.get("lora_rank", 64)
        self.lora_alpha = kwargs.get("lora_alpha", 64)
        self.lora_dropout = kwargs.get("lora_dropout", 0.05)
        self.policy_reset = kwargs.get("policy_reset", False)
        self.use_advantage_scaling = kwargs.get("use_advantage_scaling", False)