import os
from dataclasses import dataclass, field
from ..trainer.rloo_config import RLOOConfig


@dataclass
class RLOOConfigWithFeedback(RLOOConfig):
    """RLOO configuration extended with reward-model feedback options.

    Adds, on top of ``RLOOConfig``, options for the reward model (use of
    feedback, learning rate, whether to train / save it), hidden-state
    handling, feedback weighting and aggregation, two optional auxiliary
    loss terms, and the advantage-estimation method. The ``help`` entry in
    each field's metadata is the human-readable description of the option.
    How each option is interpreted is decided by the trainer that consumes
    this config, not here.
    """

    rm_with_feedback: bool = field(
        default=False,
        metadata={"help": "Use reward model with feedback."},
    )
    rm_lr: float = field(
        default=1e-6,
        metadata={"help": "Learning rate of the reward model."},
    )
    # Extra return of batch generation with hidden states.
    lqh: bool = field(
        default=False,
        metadata={"help": "Use last hiddenstate of  the last query token, currently only offer extra message for the reward model"},
    )
    pad_h: bool = field(
        default=True,
        metadata={"help": "ignore the pad tokens in hiddenstates"},
    )
    fw: float = field(
        default=0.1,
        metadata={"help": "feedback weight"},
    )
    # Accepted values are defined by the consumer; only the default "mlp"
    # is visible here.
    agg: str = field(
        default="mlp",
        metadata={"help": "the way of aggregate feedback"},
    )
    # NOTE(review): -1 is the default; presumably it disables filtering.
    # Confirm in the trainer.
    filter_length: int = field(
        default=-1,
        metadata={"help": "the length of the filter"},
    )
    # NOTE(review): if RLOOConfig ultimately derives from
    # transformers.TrainingArguments, this redeclares its `debug` field
    # (a str / list of DebugOption there) as a bool. Confirm that nothing
    # upstream iterates over `args.debug`.
    debug: bool = field(
        default=False,
        metadata={"help": "will set the model layers to 3"},
    )
    enable_lm: bool = field(
        default=False,
        metadata={"help": "whether to use multi view loss item to the Loss"},
    )
    enable_le: bool = field(
        default=False,
        metadata={"help": "whether to use reward entropy loss item to the Loss"},
    )
    # NOTE(review): may redeclare a parent field of the same name. The
    # default here is False.
    gradient_checkpointing: bool = field(
        default=False,
        metadata={"help": "whether to open model checkpointing"},
    )
    dynamic_fw: bool = field(
        default=False,
        metadata={"help": "whether to enlarge fw gradually"},
    )
    save_reward_model: bool = field(
        default=False,
        metadata={"help": "whether to save reward model"},
    )
    # Pairs with `enable_le`. The help text previously duplicated `fw`'s
    # ("feedback weight").
    le_weight: float = field(
        default=0.1,
        metadata={"help": "weight of the reward entropy loss term (used when enable_le is set)"},
    )
    advantage_estimate: str = field(
        default="rloo",
        metadata={"help": "the way of advantage estimate"},
    )
    # The help text previously duplicated `advantage_estimate`'s.
    # NOTE(review): the new wording is inferred from the field name
    # (`h_` = hidden state, cf. `pad_h`). Confirm against the consumer.
    h_mask_type: str = field(
        default="rm",
        metadata={"help": "the type of mask applied to hidden states"},
    )
    train_rm: bool = field(
        default=True,
        metadata={"help": "whether to train reward model"},
    )

    # keep_raw_score: bool=field(
    #     default=False,
    #     metadata={"help": "whether to keep the raw score of reward model"},
    # )
