from dataclasses import dataclass, field
from typing import Optional

import transformers
from trl import ModelConfig, ScriptArguments


@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Extends `trl.ScriptArguments` with reward, image-preprocessing and task-mode options.

    Args:
        reward_funcs (`list[str]`, *optional*, defaults to `["accuracy", "format"]`):
            List of reward functions. Possible values: 'accuracy', 'format'.
        max_pixels (`int` or `None`, *optional*, defaults to `12845056`):
            Maximum number of pixels for the image (for QwenVL).
        min_pixels (`int` or `None`, *optional*, defaults to `3136`):
            Minimum number of pixels for the image (for QwenVL).
        max_anyres_num (`int` or `None`, *optional*, defaults to `12`):
            Maximum number of anyres blocks for the image (for InternVL).
        image_root (`str` or `None`, *optional*, defaults to `None`):
            Root directory of the images. Presumably joined with relative image paths from the dataset;
            confirm against the data-loading code.
        task_type (`str`, *optional*, defaults to `"think"`):
            Type of task to run, e.g. `"nothink"`, `"think"`, `"answer"`. Read by the training script to
            control its behavior.
    """

    # default_factory avoids sharing one mutable list across instances.
    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )
    max_pixels: Optional[int] = field(
        default=12845056,  # = 16384 * 28 * 28
        metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
    )
    min_pixels: Optional[int] = field(
        default=3136,  # = 4 * 28 * 28
        metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
    )
    max_anyres_num: Optional[int] = field(
        default=12,
        metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
    )
    image_root: Optional[str] = field(
        default=None,
        metadata={"help": "Root directory of the image"},
    )
    task_type: str = field(
        default="think",
        metadata={"help": "Type of task to run. Options include 'nothink', 'think', 'answer', etc. This can be used to control the behavior of the training script."},
    )


@dataclass
class GRPOConfig(transformers.TrainingArguments):
    r"""
    Configuration class for the [`GRPOTrainer`].

    Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
    [`~transformers.TrainingArguments`] documentation.

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        > Parameters that control the model and reference model

        model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
            argument of the [`GRPOTrainer`] is provided as a string.

        > Parameters that control the data preprocessing

        remove_unused_columns (`bool`, *optional*, defaults to `False`):
            Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
            requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
        train_cls (`str` or `None`, *optional*, defaults to `"GRPO"`):
            Name of the trainer class to use. The accepted values are defined by the training script, not here.
        num_think_samples (`int`, *optional*, defaults to `4`):
            Number of think processes to sample per prompt (K in H-GRPO). Each think will be used to generate
            `num_answers_per_thinking` answers.
        num_answers_per_thinking (`int`, *optional*, defaults to `16`):
            Number of answer samples to generate per think process (M in H-GRPO). This allows evaluating multiple
            answers per think while only generating the think once.
        num_generations (`int` or `None`, *optional*, defaults to `64`):
            Total number of answer samples per prompt (K*M in H-GRPO). It should equal
            `num_think_samples * num_answers_per_thinking`; this class does not enforce that relationship, so
            update it whenever you change K or M. The global batch size
            (num_processes * per_device_batch_size) must be divisible by this value.
        max_completion_length (`int` or `None`, *optional*, defaults to `256`):
            Maximum length of the generated completion.
        ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
            This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
            improving generation speed. However, disabling this option allows training models that exceed the VRAM
            capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
            with vLLM generation.

        > Parameters that control generation

        temperature (`float`, defaults to `0.9`):
            Temperature for sampling. The higher the temperature, the more random the completions.
        top_p (`float`, *optional*, defaults to `1.0`):
            Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to
            `1.0` to consider all tokens.
        top_k (`int` or `None`, *optional*, defaults to `50`):
            Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is
            disabled.
        min_p (`float` or `None`, *optional*, defaults to `None`):
            Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
            value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
        repetition_penalty (`float`, *optional*, defaults to `1.0`):
            Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
            Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
            tokens.
        cache_implementation (`str` or `None`, *optional*, defaults to `None`):
            Implementation of the cache method for faster generation when use_vllm is set to False.
        stop_strings (`list[str]` or `None`, *optional*, defaults to `None`):
            List of strings where generation will stop. If `None`, generation stops when the model generates the EOS
            token.

        > Parameters that control generation acceleration powered by vLLM

        use_vllm (`bool`, *optional*, defaults to `False`):
            Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
            training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
        vllm_device (`str`, *optional*, defaults to `"auto"`):
            Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
            automatically select the next available GPU after the last one used for training. This assumes that
            training has not already occupied all available GPUs. If only one device is available, the device will be
            shared between both training and vLLM.
        vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
            Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
            device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
            improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
            during initialization.
        vllm_dtype (`str`, *optional*, defaults to `"auto"`):
            Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
            based on the model configuration. Find the supported values in the vLLM documentation.
        vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
            If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
            `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
            context size, which might be much larger than the KV cache, leading to inefficiencies.
        vllm_enable_prefix_caching (`bool`, *optional*, defaults to `True`):
            Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and the hardware
            support this feature.
        vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
            Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.

        > Parameters that control the training

        learning_rate (`float`, *optional*, defaults to `1e-6`):
            Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
            [`~transformers.TrainingArguments`].
        beta (`float`, *optional*, defaults to `0.04`):
            KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
            speed, but may be numerically unstable for long training runs.
        num_iterations (`int`, *optional*, defaults to `1`):
            Number of iterations per batch (denoted as μ in the algorithm).
        epsilon (`float`, *optional*, defaults to `0.2`):
            Epsilon value for clipping.
        epsilon_high (`float` or `None`, *optional*, defaults to `None`):
            Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
            specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
        reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
            Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
            weighted equally with weight `1.0`.
        sync_ref_model (`bool`, *optional*, defaults to `False`):
            Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
            the `ref_model_mixup_alpha` parameter. This synchronization originates from the
            [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
        ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`):
            α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
            between the current policy and the previous reference policy during updates. The reference policy is
            updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
            must set `sync_ref_model=True`.
        ref_model_sync_steps (`int`, *optional*, defaults to `512`):
            τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
            frequently the current policy is synchronized with the reference policy. To use this parameter, you must
            set `sync_ref_model=True`.
        ref_model_id (`str` or `None`, *optional*, defaults to `None`):
            Identifier of the reference model. If not specified, it defaults to the same as the model.

        > Parameters that control the logging

        log_completions (`bool`, *optional*, defaults to `False`):
            Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is
            installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
        need_gather (`bool`, *optional*, defaults to `False`):
            Whether to gather rewards from all processes for global statistics. If `False`, rewards are not gathered,
            which can be useful for debugging or when global statistics are not needed.
    """

    # NOTE: field order is significant for dataclasses (positional __init__ order); keep new fields at the end.

    # Parameters that control the model and reference model
    model_init_kwargs: Optional[dict] = field(
        default=None,
        metadata={"help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
                          "argument of the `GRPOTrainer` is provided as a string."},
    )

    # Parameters that control the data preprocessing
    # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
    # additional columns to compute the reward
    remove_unused_columns: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
                          "that requires any column other than 'prompts' and 'completions', you should keep this to `False`."},
    )
    max_prompt_length: Optional[int] = field(
        default=512,
        metadata={"help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."},
    )
    train_cls: Optional[str] = field(
        default="GRPO",
        metadata={"help": "trainer class"},
    )
    num_think_samples: int = field(
        default=4,
        metadata={"help": "Number of think processes to sample per prompt (K in H-GRPO). Each think will be used to "
                          "generate `num_answers_per_thinking` answers."},
    )
    num_answers_per_thinking: int = field(
        default=16,  # Default M value
        metadata={"help": "Number of answer samples to generate per think process (M in H-GRPO). This allows evaluating "
                          "multiple answers per think while only generating the think once."},
    )
    num_generations: Optional[int] = field(
        # Should be num_think_samples * num_answers_per_thinking (4 * 16 = 64); not enforced by this class.
        default=64,
        metadata={"help": "Total number of answer samples per prompt (K*M in H-GRPO). The global batch size "
                          "(num_processes * per_device_batch_size) must be divisible by this value."},
    )
    max_completion_length: Optional[int] = field(
        default=256,
        metadata={"help": "Maximum length of the generated completion."},
    )
    ds3_gather_for_generation: bool = field(
        default=True,
        metadata={
            "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
                    "generation, improving generation speed. However, disabling this option allows training models that "
                    "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option "
                    "is not compatible with vLLM generation."
        },
    )

    # Parameters that control generation
    temperature: float = field(
        default=0.9,
        metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
    )
    top_p: float = field(
        default=1.0,
        metadata={"help": "Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. "
                          "Set to 1.0 to consider all tokens."},
    )
    top_k: Optional[int] = field(
        default=50,
        metadata={"help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, "
                          "top-k-filtering is disabled."},
    )
    min_p: Optional[float] = field(
        default=None,
        metadata={"help": "Minimum token probability, which will be scaled by the probability of the most likely token. It "
                          "must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range."},
    )
    repetition_penalty: float = field(
        default=1.0,
        metadata={"help": "Float that penalizes new tokens based on whether they appear in the prompt and the generated "
                          "text so far. Values > 1.0 encourage the model to use new tokens, while values < 1.0 encourage the model "
                          "to repeat tokens."},
    )
    cache_implementation: Optional[str] = field(
        default=None,
        metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."},
    )

    # Parameters that control generation acceleration powered by vLLM
    use_vllm: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept "
                          "unused for training, as vLLM will require one for generation. vLLM must be installed "
                          "(`pip install vllm`)."},
    )
    vllm_device: Optional[str] = field(
        default="auto",
        metadata={
            "help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system "
                    "will automatically select the next available GPU after the last one used for training. This assumes "
                    "that training has not already occupied all available GPUs."
        },
    )
    vllm_gpu_memory_utilization: float = field(
        default=0.9,
        metadata={
            "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
                    "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
                    "size and thus improve the model's throughput. However, if the value is too high, it may cause "
                    "out-of-memory (OOM) errors during initialization."
        },
    )
    vllm_dtype: Optional[str] = field(
        default="auto",
        metadata={"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
                          "determined based on the model configuration. Find the supported values in the vLLM documentation."},
    )
    vllm_max_model_len: Optional[int] = field(
        default=None,
        metadata={
            "help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced "
                    "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
                    "context size, which might be much larger than the KV cache, leading to inefficiencies."
        },
    )
    vllm_enable_prefix_caching: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and "
                          "the hardware support this feature."},
    )
    vllm_guided_decoding_regex: Optional[str] = field(
        default=None,
        metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
    )

    # Parameters that control the training
    learning_rate: float = field(
        default=1e-6,
        metadata={"help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
                          "`transformers.TrainingArguments`."},
    )
    beta: float = field(
        default=0.04,
        metadata={"help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
                          "training speed, but may be numerically unstable for long training runs."},
    )
    num_iterations: int = field(
        default=1,
        metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
    )
    epsilon: float = field(
        default=0.2,
        metadata={"help": "Epsilon value for clipping."},
    )
    epsilon_high: Optional[float] = field(
        default=None,
        metadata={"help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the "
                          "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`."},
    )
    reward_weights: Optional[list[float]] = field(
        default=None,
        metadata={"help": "Weights for each reward function. Must match the number of reward functions. If `None`, all "
                          "rewards are weighted equally with weight `1.0`."},
    )
    sync_ref_model: bool = field(
        default=False,
        metadata={"help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` "
                          "steps, using the `ref_model_mixup_alpha` parameter."},
    )
    ref_model_mixup_alpha: float = field(
        default=0.6,
        metadata={
            "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "
                    "previous reference policy during updates. The reference policy is updated according to the equation: "
                    "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`."
        },
    )
    ref_model_sync_steps: int = field(
        default=512,
        metadata={"help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
                          "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."},
    )

    # Parameters that control the logging
    log_completions: bool = field(
        default=False,
        metadata={"help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is "
                          "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."},
    )
    need_gather: bool = field(
        default=False,
        metadata={"help": "Whether to gather rewards from all processes for global statistics. If set to `False`, the "
                          "rewards will not be gathered, which can be useful for debugging or when global statistics are not needed."},
    )
    ref_model_id: Optional[str] = field(
        default=None,
        metadata={"help": "Identifier of the reference model. If not specified, it defaults to the same as the model."},
    )
    stop_strings: Optional[list[str]] = field(
        default=None,
        metadata={"help": "List of strings where generation will stop. If `None`, generation stops when the model generates the EOS token."},
    )
