from dataclasses import dataclass, field
import torch


@dataclass
class BnBConfig:
    """Options for 4-bit quantized model loading.

    The two ``*_dtype`` fields may be supplied either as the name of a torch
    dtype (e.g. ``"bfloat16"``) or as a ``torch.dtype`` object. After
    construction, every value that names a real torch dtype has been replaced
    by the corresponding ``torch.dtype``; any other value is kept unchanged.
    """

    load_in_4bit: bool = True
    bnb_4bit_use_double_quant: bool = True
    bnb_4bit_quant_type: str = "nf4"
    # Declared as ``str`` (the configured form); holds a ``torch.dtype`` once
    # ``__post_init__`` has run and the name was recognized.
    bnb_4bit_compute_dtype: str = "bfloat16"
    bnb_4bit_quant_storage_dtype: str = "bfloat16"

    def __post_init__(self):
        """Normalize the dtype fields from names to ``torch.dtype`` objects."""
        self.bnb_4bit_compute_dtype = self._resolve_dtype(
            self.bnb_4bit_compute_dtype
        )
        self.bnb_4bit_quant_storage_dtype = self._resolve_dtype(
            self.bnb_4bit_quant_storage_dtype
        )

    @staticmethod
    def _resolve_dtype(value):
        """Return the ``torch.dtype`` named by *value*, else *value* itself.

        Non-string inputs (typically an already-resolved ``torch.dtype``) are
        passed through untouched: ``hasattr``/``getattr`` raise ``TypeError``
        on non-string attribute names, which would otherwise make the config
        impossible to build from dtype objects or to re-normalize.
        """
        if not isinstance(value, str):
            return value
        candidate = getattr(torch, value, None)
        # Only accept genuine dtypes; a string such as "nn" names a torch
        # attribute but must not be swapped for the ``torch.nn`` module.
        if isinstance(candidate, torch.dtype):
            return candidate
        return value
