# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional

from sympy import Float
import trl

@dataclass
class GRPOConfig(trl.GRPOConfig):
    """
    Project-specific extension of ``trl.GRPOConfig``.

    Adds options for post-training benchmarks and training callbacks, Hub and
    W&B reporting, CFG (guidance) sampling, loss clipping, a self-defined
    learning-rate schedule, reward functions, tokenizer locations, dataset
    selection and image saving.

    Field order is significant (it is the dataclass ``__init__`` order), so new
    fields should be appended rather than inserted.
    """

    # --- Benchmarks / callbacks ------------------------------------------------
    benchmarks: list[str] = field(
        default_factory=list, metadata={"help": "The benchmarks to run after training."}
    )
    callbacks: list[str] = field(
        default_factory=list, metadata={"help": "The callbacks to run during training."}
    )
    system_prompt: Optional[str] = field(
        default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
    )

    # --- Hub / W&B reporting ---------------------------------------------------
    hub_model_revision: Optional[str] = field(
        default="main", metadata={"help": "The Hub model branch to push the model to."}
    )
    overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
    push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
    wandb_entity: Optional[str] = field(
        default=None,
        metadata={"help": "The entity to store runs under."},
    )
    wandb_project: Optional[str] = field(
        default=None,
        metadata={"help": "The project to store runs under."},
    )

    # --- Sampling / reward options ---------------------------------------------
    # NOTE(review): placeholder help text; the meaning of `weight` is not
    # visible from this file — document it.
    weight: Optional[float] = field(
        default=0.5,
        metadata={"help": "help"},
    )
    # NOTE(review): placeholder help text; presumably the CFG guidance scale
    # used when `generate_with_cfg` is set — confirm and document.
    guidance_scale: Optional[float] = field(
        default=None,
        metadata={"help": "help"},
    )
    generate_with_cfg: bool = field(
        default=False,
        metadata={"help": "generate samples with cfg"},
    )
    use_clip_score: Optional[bool] = field(
        default=True,
        metadata={"help": "use clip score in reward model, default to be true"},
    )
    reverse: bool = field(default=False, metadata={"help": "whether to use the reversed token"})

    # --- Loss clipping ---------------------------------------------------------
    set_epsilon: bool = field(
        default=False,
        metadata={"help": "whether to set_epsilon."},
    )
    # Annotated with the builtin ``float`` so that parsers which construct the
    # value from the field type (e.g. transformers' HfArgumentParser) yield a
    # plain Python float, matching the parent class's own `epsilon` field.
    epsilon: float = field(
        default=0.2,
        metadata={"help": "coefficient that clip the loss"},
    )
    epsilon_low: float = field(
        default=0.2,
        metadata={"help": "lower bound that clip the loss"},
    )
    epsilon_high: float = field(
        default=0.4,
        metadata={"help": "upper bound that clip the loss"},
    )

    # --- Learning-rate schedule ------------------------------------------------
    # Presumably consumed by the self-defined scheduler enabled through
    # `use_self_lr_scheduler` — confirm against the trainer code.
    warm_up_steps: int = field(
        default=45,
        metadata={"help": "warm up steps"},
    )
    convert_steps: int = field(
        default=500,
        metadata={"help": "steps that lr begins to drop"},
    )
    convert_lr: float = field(
        default=1e-6,
        metadata={"help": "lr that the optimizer drops to after the convert_steps"},
    )
    min_lr: float = field(
        default=2e-7,
        metadata={"help": "min lr"},
    )
    model_path: str = field(
        default="",
        metadata={"help": "default model path"},
    )
    use_self_lr_scheduler: bool = field(
        default=False,
        metadata={"help": "Whether to use the self-defined learning rate scheduler."},
    )

    # --- Tokenizers / rewards / task -------------------------------------------
    tokenizer_path: Optional[str] = field(
        default=None,
        metadata={"help": "tokenizer path"},
    )
    reward_list: list[str] = field(
        default_factory=list,
        metadata={"help": "the list of reward functions which will be used in training"},
    )
    llama_tokenizer_path: str = field(
        default="",
        metadata={"help": "llama3 tokenizer path"},
    )
    task_type: str = field(
        default="t2i",
        metadata={"help": "training task type"},
    )
    selftok_tokenizer_path: str = field(
        default="",
        metadata={"help": "selftok tokenizer path"},
    )
    port: int = field(
        default=56950,
        metadata={"help": "the port of selftok tokenizer"},
    )
    selftok_config: str = field(
        default="",
        metadata={"help": "the config of selftok tokenizer"},
    )
    use_std_reward: bool = field(
        default=True,
        metadata={"help": "whether to use the standard general reward instead of map it to [-1, 1]"},
    )

    # --- Data ------------------------------------------------------------------
    dataset_type: Optional[str] = field(
        default=None,
        metadata={"help": "determine which dataset to use"},
    )
    data_source: list[str] = field(
        default_factory=list,
        metadata={"help": "the list of data that will be used in training"},
    )
    # Must be index-aligned with `data_source` (one path per source).
    datap: list[str] = field(
        default_factory=list,
        metadata={"help": "the path of data, should be consistent with the order of items in data_source"},
    )
    mox_path: Optional[str] = field(
        default=None,
        metadata={"help": "save the model to s3 path"},
    )
    # NOTE(review): defaults to None rather than False, so downstream code may
    # distinguish "unset" from an explicit False; kept as-is for compatibility.
    use_api: bool = field(
        default=None,
        metadata={"help": "use api for reward model"},
    )
    reward_model_path: Optional[str] = field(
        default=None,
        metadata={"help": "the path of reward model"},
    )
    data_config_path: Optional[str] = field(
        default=None,
        metadata={"help": "the yml path of data when train task is editing"},
    )

    # --- Vocabulary / output ---------------------------------------------------
    text_vocab_size: int = field(
        default=128256,
        metadata={"help": "the text vocabulary size"},
    )
    image_vocab_size: int = field(
        default=32768,
        metadata={"help": "the image vocabulary size"},
    )
    image_save_path: Optional[str] = field(
        default=None,
        metadata={"help": "path to save image during training"},
    )

    # --- CFG policy ------------------------------------------------------------
    cfg_type: str = field(
        default="fix",
        metadata={"help": 'cfg policy when training and sampling: "fix", "adaptive"'},
    )
    entropy_bound: int = field(
        default=2,
        metadata={"help": "entropy bound for cfg"},
    )
    min_cfg: int = field(
        default=1,
        metadata={"help": "min cfg"},
    )

    # --- Captions / GRPO iterations --------------------------------------------
    use_raw_caption: bool = field(
        default=True,
        metadata={"help": "whether to use raw caption in pangu dataset"},
    )
    # Field name is misspelled ("extented") but kept for CLI/config compatibility.
    use_extented_raw_caption: bool = field(
        default=False,
        metadata={"help": "whether to use extended caption in pangu dataset"},
    )
    # NOTE(review): may shadow a same-named field of trl.GRPOConfig depending
    # on the installed trl version — confirm.
    num_iterations: int = field(
        default=1,
        metadata={"help": "𝜇 in the GRPO paper"},
    )
