"""Configuration for CMI Loss implementation."""

from dataclasses import dataclass
from typing import Optional


class CMILossConfigError(ValueError, AssertionError):
    """Raised when a ``CMILossConfig`` field holds an invalid value.

    Inherits from ``AssertionError`` so that callers written against the
    earlier ``assert``-based validation keep working, and from ``ValueError``
    because that is the conventional type for a bad argument value.
    """


@dataclass
class CMILossConfig:
    """Configuration parameters for CMI Loss training.

    All fields are validated on construction (see ``__post_init__``).

    Attributes:
        cmi_lambda: Final regularization strength (negative value for
            encouragement). Must be <= 0.
        cmi_lambda_start: Starting value for lambda during rampup phase.
            Must be <= 0 and no larger in magnitude than ``cmi_lambda``.
        cmi_warmup_ratio: Fraction of training steps for warmup (standard SFT).
            Must be in [0, 1).
        cmi_rampup_ratio: Fraction of training steps for ramping up lambda.
            Must be in [0, 1); warmup + rampup must not exceed 1.
        cmi_thinking_weight: Weight for thinking tokens in shortcut loss
            (0=mask, 1=full). Must be in [0, 1].
        cmi_loss_normalize: Whether to normalize losses to prevent scale
            mismatch.
        cmi_apply_to_harmful_only: Apply CMI only to harmful samples if True.
    """

    cmi_lambda: float = -0.1
    cmi_lambda_start: float = -0.01
    cmi_warmup_ratio: float = 0.3
    cmi_rampup_ratio: float = 0.5
    cmi_thinking_weight: float = 0.1
    cmi_loss_normalize: bool = True
    cmi_apply_to_harmful_only: bool = False

    def __post_init__(self) -> None:
        """Validate configuration parameters.

        Checks run in a fixed order and stop at the first failure.

        Raises:
            CMILossConfigError: If any field is outside its allowed range.
                This is both a ``ValueError`` and an ``AssertionError``.
        """
        # Explicit raises instead of `assert`: assert statements are removed
        # when Python runs with -O, which would silently disable validation.
        # Each condition is written as `not (a <= b)` rather than `a > b` so
        # that NaN values (for which every comparison is False) are rejected.
        if not (self.cmi_lambda <= 0):
            raise CMILossConfigError(
                "cmi_lambda should be negative for encouragement"
            )
        if not (self.cmi_lambda_start <= 0):
            raise CMILossConfigError("cmi_lambda_start should be negative")
        if not (abs(self.cmi_lambda_start) <= abs(self.cmi_lambda)):
            raise CMILossConfigError(
                "cmi_lambda_start should have smaller magnitude"
            )
        if not (0 <= self.cmi_warmup_ratio < 1):
            raise CMILossConfigError("cmi_warmup_ratio must be in [0, 1)")
        if not (0 <= self.cmi_rampup_ratio < 1):
            raise CMILossConfigError("cmi_rampup_ratio must be in [0, 1)")
        if not (self.cmi_warmup_ratio + self.cmi_rampup_ratio <= 1):
            raise CMILossConfigError("warmup + rampup cannot exceed 1")
        if not (0 <= self.cmi_thinking_weight <= 1):
            raise CMILossConfigError("cmi_thinking_weight must be in [0, 1]")