# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
from dataclasses import dataclass, field
from typing import Optional

from transformers import is_bitsandbytes_available

from ..core import flatten_dict


@dataclass
class DDPOConfig:
    r"""
    Configuration class for the [`DDPOTrainer`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        exp_name (`str`, *optional*, defaults to `os.path.splitext(os.path.basename(sys.argv[0]))[0]`):
            Name of this experiment (by default is the file name without the extension name). The default is computed
            once, when this class is defined (i.e. at import time).
        run_name (`str`, *optional*, defaults to `""`):
            Name of this run.
        seed (`int`, *optional*, defaults to `0`):
            Random seed.
        log_with (`Literal["wandb", "tensorboard"]` or `None`, *optional*, defaults to `None`):
            Log with either 'wandb' or 'tensorboard', check
            https://huggingface.co/docs/accelerate/usage_guides/tracking for more details.
        tracker_kwargs (`dict`, *optional*, defaults to `{}`):
            Keyword arguments for the tracker (e.g. wandb_project).
        accelerator_kwargs (`dict`, *optional*, defaults to `{}`):
            Keyword arguments for the accelerator.
        project_kwargs (`dict`, *optional*, defaults to `{}`):
            Keyword arguments for the accelerator project config (e.g. `logging_dir`).
        tracker_project_name (`str`, *optional*, defaults to `"trl"`):
            Name of project to use for tracking.
        logdir (`str`, *optional*, defaults to `"logs"`):
            Top-level logging directory for checkpoint saving.
        num_epochs (`int`, *optional*, defaults to `100`):
            Number of epochs to train.
        save_freq (`int`, *optional*, defaults to `1`):
            Number of epochs between saving model checkpoints.
        num_checkpoint_limit (`int`, *optional*, defaults to `5`):
            Number of checkpoints to keep before overwriting old ones.
        mixed_precision (`str`, *optional*, defaults to `"fp16"`):
            Mixed precision training.
        allow_tf32 (`bool`, *optional*, defaults to `True`):
            Allow `tf32` on Ampere GPUs.
        resume_from (`str`, *optional*, defaults to `""`):
            Resume training from a checkpoint.
        sample_num_steps (`int`, *optional*, defaults to `50`):
            Number of sampler inference steps.
        sample_eta (`float`, *optional*, defaults to `1.0`):
            Eta parameter for the DDIM sampler.
        sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
            Classifier-free guidance weight.
        sample_batch_size (`int`, *optional*, defaults to `1`):
            Batch size (per GPU) to use for sampling.
        sample_num_batches_per_epoch (`int`, *optional*, defaults to `2`):
            Number of batches to sample per epoch.
        train_batch_size (`int`, *optional*, defaults to `1`):
            Batch size (per GPU) to use for training.
        train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
            Use 8bit Adam optimizer from bitsandbytes. Requires `bitsandbytes` to be installed.
        train_learning_rate (`float`, *optional*, defaults to `3e-4`):
            Learning rate.
        train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
            Adam beta1.
        train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
            Adam beta2.
        train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
            Adam weight decay.
        train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
            Adam epsilon.
        train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
            Number of gradient accumulation steps.
        train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
            Maximum gradient norm for gradient clipping.
        train_num_inner_epochs (`int`, *optional*, defaults to `1`):
            Number of inner epochs per outer epoch.
        train_cfg (`bool`, *optional*, defaults to `True`):
            Whether to use classifier-free guidance during training.
        train_adv_clip_max (`float`, *optional*, defaults to `5.0`):
            Clip advantages to the range `[-train_adv_clip_max, train_adv_clip_max]`.
        train_clip_range (`float`, *optional*, defaults to `1e-4`):
            PPO clip range.
        train_timestep_fraction (`float`, *optional*, defaults to `1.0`):
            Fraction of timesteps to train on.
        per_prompt_stat_tracking (`bool`, *optional*, defaults to `False`):
            Whether to track statistics for each prompt separately.
        per_prompt_stat_tracking_buffer_size (`int`, *optional*, defaults to `16`):
            Number of reward values to store in the buffer for each prompt.
        per_prompt_stat_tracking_min_count (`int`, *optional*, defaults to `16`):
            Minimum number of reward values to store in the buffer.
        async_reward_computation (`bool`, *optional*, defaults to `False`):
            Whether to compute rewards asynchronously.
        max_workers (`int`, *optional*, defaults to `2`):
            Maximum number of workers to use for async reward computation.
        negative_prompts (`str`, *optional*, defaults to `""`):
            Comma-separated list of prompts to use as negative examples.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the final model checkpoint to the Hub.
    """

    # `splitext` strips only a real extension. The former `[: -len(".py")]` slice
    # chopped 3 characters off any script name (e.g. "/usr/bin/pytest" -> "pyt").
    # Evaluated once, at class-definition time.
    exp_name: str = field(
        default=os.path.splitext(os.path.basename(sys.argv[0]))[0],
        metadata={"help": "Name of this experiment (by default is the file name without the extension name)."},
    )
    run_name: str = field(
        default="",
        metadata={"help": "Name of this run."},
    )
    seed: int = field(
        default=0,
        metadata={"help": "Random seed."},
    )
    # `choices` is only enforced when parsing CLI arguments; direct construction is not validated.
    log_with: Optional[str] = field(
        default=None,
        metadata={
            "help": "Log with either 'wandb' or 'tensorboard'.",
            "choices": ["wandb", "tensorboard"],
        },
    )
    # Mutable defaults use `default_factory` so instances never share one dict.
    tracker_kwargs: dict = field(
        default_factory=dict,
        metadata={"help": "Keyword arguments for the tracker (e.g. wandb_project)."},
    )
    accelerator_kwargs: dict = field(
        default_factory=dict,
        metadata={"help": "Keyword arguments for the accelerator."},
    )
    project_kwargs: dict = field(
        default_factory=dict,
        metadata={"help": "Keyword arguments for the accelerator project config (e.g. `logging_dir`)."},
    )
    tracker_project_name: str = field(
        default="trl",
        metadata={"help": "Name of project to use for tracking."},
    )
    logdir: str = field(
        default="logs",
        metadata={"help": "Top-level logging directory for checkpoint saving."},
    )
    num_epochs: int = field(
        default=100,
        metadata={"help": "Number of epochs to train."},
    )
    save_freq: int = field(
        default=1,
        metadata={"help": "Number of epochs between saving model checkpoints."},
    )
    num_checkpoint_limit: int = field(
        default=5,
        metadata={"help": "Number of checkpoints to keep before overwriting old ones."},
    )
    mixed_precision: str = field(
        default="fp16",
        metadata={"help": "Mixed precision training."},
    )
    allow_tf32: bool = field(
        default=True,
        metadata={"help": "Allow `tf32` on Ampere GPUs."},
    )
    resume_from: str = field(
        default="",
        metadata={"help": "Resume training from a checkpoint."},
    )
    sample_num_steps: int = field(
        default=50,
        metadata={"help": "Number of sampler inference steps."},
    )
    sample_eta: float = field(
        default=1.0,
        metadata={"help": "Eta parameter for the DDIM sampler."},
    )
    sample_guidance_scale: float = field(
        default=5.0,
        metadata={"help": "Classifier-free guidance weight."},
    )
    sample_batch_size: int = field(
        default=1,
        metadata={"help": "Batch size (per GPU) to use for sampling."},
    )
    sample_num_batches_per_epoch: int = field(
        default=2,
        metadata={"help": "Number of batches to sample per epoch."},
    )
    train_batch_size: int = field(
        default=1,
        metadata={"help": "Batch size (per GPU) to use for training."},
    )
    train_use_8bit_adam: bool = field(
        default=False,
        metadata={"help": "Use 8bit Adam optimizer from bitsandbytes."},
    )
    train_learning_rate: float = field(
        default=3e-4,
        metadata={"help": "Learning rate."},
    )
    train_adam_beta1: float = field(
        default=0.9,
        metadata={"help": "Adam beta1."},
    )
    train_adam_beta2: float = field(
        default=0.999,
        metadata={"help": "Adam beta2."},
    )
    train_adam_weight_decay: float = field(
        default=1e-4,
        metadata={"help": "Adam weight decay."},
    )
    train_adam_epsilon: float = field(
        default=1e-8,
        metadata={"help": "Adam epsilon."},
    )
    train_gradient_accumulation_steps: int = field(
        default=1,
        metadata={"help": "Number of gradient accumulation steps."},
    )
    train_max_grad_norm: float = field(
        default=1.0,
        metadata={"help": "Maximum gradient norm for gradient clipping."},
    )
    train_num_inner_epochs: int = field(
        default=1,
        metadata={"help": "Number of inner epochs per outer epoch."},
    )
    train_cfg: bool = field(
        default=True,
        metadata={"help": "Whether to use classifier-free guidance during training."},
    )
    train_adv_clip_max: float = field(
        default=5.0,
        metadata={"help": "Clip advantages to the range [-train_adv_clip_max, train_adv_clip_max]."},
    )
    train_clip_range: float = field(
        default=1e-4,
        metadata={"help": "PPO clip range."},
    )
    train_timestep_fraction: float = field(
        default=1.0,
        metadata={"help": "Fraction of timesteps to train on."},
    )
    per_prompt_stat_tracking: bool = field(
        default=False,
        metadata={"help": "Whether to track statistics for each prompt separately."},
    )
    per_prompt_stat_tracking_buffer_size: int = field(
        default=16,
        metadata={"help": "Number of reward values to store in the buffer for each prompt."},
    )
    per_prompt_stat_tracking_min_count: int = field(
        default=16,
        metadata={"help": "Minimum number of reward values to store in the buffer."},
    )
    async_reward_computation: bool = field(
        default=False,
        metadata={"help": "Whether to compute rewards asynchronously."},
    )
    max_workers: int = field(
        default=2,
        metadata={"help": "Maximum number of workers to use for async reward computation."},
    )
    negative_prompts: str = field(
        default="",
        metadata={"help": "Comma-separated list of prompts to use as negative examples."},
    )
    push_to_hub: bool = field(
        default=False,
        metadata={"help": "Whether to push the final model checkpoint to the Hub."},
    )

    def to_dict(self) -> dict:
        """
        Return the instance attributes as a dictionary passed through `flatten_dict`.

        A shallow copy of the instance's attribute dict is handed to `flatten_dict`,
        so the live `__dict__` is never passed in directly.
        """
        return flatten_dict(dict(vars(self)))

    def __post_init__(self):
        """
        Validate the configuration after dataclass initialisation.

        Raises:
            ImportError: if `train_use_8bit_adam` is enabled but `bitsandbytes` is not installed.
        """
        # The availability check only runs when 8-bit Adam is requested (short-circuit `and`).
        if self.train_use_8bit_adam and not is_bitsandbytes_available():
            raise ImportError(
                "You need to install bitsandbytes to use 8bit Adam. "
                "You can install it with `pip install bitsandbytes`."
            )
