

def parse_quant_format(quant_format_enum):
    """Split a quantization format enum into its bit width and family.

    The enum's ``value`` is expected to be a string such as ``'int8'``,
    ``'fp4_e2m1'`` or ``'fp8_e4m3'``: a family prefix, immediately followed
    by the bit width in decimal, optionally followed by a suffix.

    Args:
        quant_format_enum: Object exposing the format string as ``.value``.

    Returns:
        A ``(bits, format_type)`` tuple, where ``format_type`` is ``"int"``
        for ``int*`` formats and ``"mxfp"`` for ``fp*`` formats.

    Raises:
        ValueError: If the prefix is not recognized or no bit width follows
            the prefix.
    """
    format_str = quant_format_enum.value

    # (string prefix, family name reported to the caller)
    for prefix, format_type in (("int", "int"), ("fp", "mxfp")):
        if not format_str.startswith(prefix):
            continue

        # Consume the whole run of digits after the prefix so multi-digit
        # widths (e.g. 'fp16_...') are not truncated to their first digit,
        # and trailing suffixes (e.g. '_e2m1') are ignored.
        start = end = len(prefix)
        while end < len(format_str) and format_str[end] in "0123456789":
            end += 1

        if end == start:
            raise ValueError(
                f"Missing bit width in quantization format: {format_str}"
            )
        return int(format_str[start:end]), format_type

    raise ValueError(f"Unknown quantization format: {format_str}")


def ensure_model_pad_id(model, tokenizer=None):
    """Make sure ``model.config.pad_token_id`` is set, filling it if missing.

    If the config already has a non-``None`` ``pad_token_id`` nothing is
    changed. Otherwise the first available id is used, in this order:
    ``tokenizer.pad_token_id``, ``tokenizer.eos_token_id``, then
    ``model.config.eos_token_id``.

    Args:
        model: Object with a ``config`` attribute; the config is mutated in
            place.
        tokenizer: Optional object exposing ``pad_token_id`` and
            ``eos_token_id``.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        ValueError: If no pad or EOS token id can be found.
    """
    if getattr(model.config, "pad_token_id", None) is not None:
        return model

    # Compare against None explicitly: token id 0 is a legitimate id and
    # must not be treated as "missing" the way a truthiness test would.
    pad_id = None
    if tokenizer is not None:
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = tokenizer.eos_token_id
    if pad_id is None:
        pad_id = getattr(model.config, "eos_token_id", None)

    # NOTE(review): some configs appear to store several EOS ids as a
    # list; a pad id must be a single id, so take the first one.
    if isinstance(pad_id, (list, tuple)):
        pad_id = pad_id[0] if pad_id else None

    if pad_id is None:
        raise ValueError("No pad/eos token id found to set as pad_token_id.")

    model.config.pad_token_id = pad_id
    return model