from cv2 import subtract
from trl import GRPOConfig
from dataclasses import dataclass, field

@dataclass
class RefGuidedVIConfig(GRPOConfig):
    """GRPO configuration extended with reference-guided options.

    Every attribute below is an additional hyperparameter layered on top of
    ``GRPOConfig``; the ``help`` metadata of each field is its authoritative
    description. This class only declares the options and their defaults --
    how each value is consumed is decided by the code that reads this config,
    which is not part of this file.
    """

    # ------------------------------------------------------------------
    # Loss / reward weights
    # ------------------------------------------------------------------
    prob_reward_weight: float = field(
        default=1.0,
        metadata={"help": "The weight for the reference answer probability reward."},
    )

    z_kl_beta: float = field(
        default=1e-3,
        metadata={"help": "The weight for the z kl."},
    )

    sft_beta: float = field(
        default=1.0,
        metadata={"help": "The weight for the sft loss."},
    )

    # ------------------------------------------------------------------
    # KL estimation
    # ------------------------------------------------------------------
    kl_estimator: str = field(
        default="k3",
        metadata={"help": "k1 or k3, The KL divergence estimator to use."},
    )

    # Lower / upper bounds for the ratio r that appears inside log(r).
    min_r: float = field(
        default=1e-2,
        metadata={"help": "The minimum value for the r in logr in kl computing."},
    )

    max_r: float = field(
        default=100.0,
        metadata={"help": "The maximum value for the r in logr in kl computing."},
    )

    # ------------------------------------------------------------------
    # Rule-based rewards
    # ------------------------------------------------------------------
    format_wrong_reward: float = field(
        default=-1.0,
        metadata={"help": "Reward for format errors in the response."},
    )

    # NOTE(review): the two options below are currently disabled; they are
    # kept here for reference only -- confirm whether they can be deleted.
    # reference_leakage_reward: float = field(
    #     default=-0.5,
    #     metadata={"help": "Reward for reference leakage in the response."},
    # )

    # invalid_reasoning_in_response_reward: float = field(
    #     default=-0.5,
    #     metadata={"help": "Reward for invalid reasoning in the response."},
    # )

    # ------------------------------------------------------------------
    # Probability reward
    # ------------------------------------------------------------------
    # Default is written as a float literal so that the stored value matches
    # the declared ``float`` type (it was previously the int literal ``1``).
    min_prob_reward_ratio: float = field(
        default=1.0,
        metadata={"help": "The minimum value for the probability reward ratio."},
    )

    max_prob_reward_ratio: float = field(
        default=10.0,
        metadata={"help": "The maximum value for the probability reward ratio."},
    )

    prob_model: str = field(
        default="self",
        metadata={"help": "The model to use for computing the probability reward. Can be 'self' or 'ref'."},
    )

    # ------------------------------------------------------------------
    # Coefficients of the individual loss terms
    # ------------------------------------------------------------------
    z_kl_constraint_coef: float = field(
        default=0.5,
        metadata={"help": "The coefficient for the z kl constraint."},
    )

    z_kl_learning_coef: float = field(
        default=0.5,
        metadata={"help": "The coefficient for the z kl learning."},
    )

    p_grpo_loss_coef: float = field(
        default=0.5,
        metadata={"help": "The coefficient for the P grpo loss."},
    )

    q_grpo_loss_coef: float = field(
        default=0.5,
        metadata={"help": "The coefficient for the Q grpo loss."},
    )

    # ------------------------------------------------------------------
    # Strategy selectors (string-valued)
    # ------------------------------------------------------------------
    # NOTE(review): the set of accepted values for the next two options is
    # not visible in this file -- verify against the code that reads them.
    prob_reward_baseline: str = field(
        default="naive_group_mean",
        metadata={"help": "prob reward baseline."},
    )

    z_kl_sample_weight: str = field(
        default="clipped_prob_gain",
        metadata={"help": "z kl sample weight."},
    )

    answer_prefix: str = field(
        default="simple_prefix",
        metadata={"help": "The prefix before the answer, can be 'simple_prefix' or 'none'."},
    )

    def __post_init__(self) -> None:
        """Delegate to the parent post-init; no extra processing is added here."""
        super().__post_init__()

