from dataclasses import dataclass
from typing import Self

import torch
from omegaconf import DictConfig, OmegaConf


@dataclass
class ConfigGradScaler():
    """Settings for the gradient scaler used in mixed-precision training.

    Fields are validated in ``__post_init__`` so that an invalid config fails
    as soon as it is loaded.
    """

    enabled: bool
    scale_init: float       # must be a float and >= scale_min
    scale_min: float        # must be a float and >= 1
    growth_interval: int

    def __post_init__(self):
        """Validate the types and ranges of the scale fields.

        Raises:
            TypeError: if ``scale_init`` or ``scale_min`` is not a float.
            ValueError: if ``scale_min`` is lower than 1, or if
                ``scale_init`` is lower than ``scale_min``.
        """
        # Type checks come first so that a non-numeric value gets a clear
        # message instead of an opaque comparison error further down.
        # Explicit raises (instead of ``assert``) keep the validation active
        # even when Python runs with ``-O``.
        if not isinstance(self.scale_init, float):
            raise TypeError("Scale init must be a float, otherwise gradscaler will return an error")
        if not isinstance(self.scale_min, float):
            raise TypeError("Scale min must be a float, otherwise gradscaler will return an error")
        if self.scale_min < 1:
            raise ValueError("Scale min lower than 1 makes no sense for mixed precision training")
        if self.scale_init < self.scale_min:
            raise ValueError("Scale init must be greater than or equal to scale min")


@dataclass
class ConfigOptim():
    """Optimisation / training-loop settings.

    Use :meth:`from_hydra` to build an instance from a Hydra ``DictConfig``;
    it turns the nested ``grad_scaler`` section into a ``ConfigGradScaler``.
    """

    steps: int
    log_every_n_steps: int
    eval_every_n_steps: int
    batch_size: int
    gradient_accumulation_steps: int
    lr: float
    weight_decay: float
    beta1: float
    beta2: float
    warmup_steps: int
    cosine_scheduler: bool
    max_grad_norm: float
    label_smoothing: float
    use_pretrained_weights: bool
    path_to_weights: str
    precision: str          # name of a torch dtype attribute, e.g. "float16" or "bfloat16"
    grad_scaler: ConfigGradScaler

    @classmethod
    def from_hydra(cls, cfg_hydra: DictConfig) -> Self:
        """Create a ``ConfigOptim`` from a Hydra/OmegaConf config node.

        Args:
            cfg_hydra: config whose keys match this dataclass's fields, with
                ``grad_scaler`` given as a nested mapping of
                ``ConfigGradScaler`` fields.

        Returns:
            The populated ``ConfigOptim``.

        Raises:
            KeyError: if the config has no ``grad_scaler`` entry.
        """
        # resolve=True turns ``${...}`` interpolations into concrete values;
        # without it they would reach the dataclass as raw strings.
        cfg_dict: dict = OmegaConf.to_container(cfg_hydra, resolve=True)  # type: ignore
        # Build the nested section from the same resolved container so both
        # levels are converted identically.
        grad_scaler = ConfigGradScaler(**cfg_dict.pop("grad_scaler"))

        return cls(
            grad_scaler=grad_scaler,
            **cfg_dict
        )

    def __post_init__(self):
        """Check that ``precision`` names a dtype exposed by ``torch``.

        Raises:
            ValueError: if ``getattr(torch, precision)`` is missing or is not
                a ``torch.dtype``.
        """
        # hasattr(torch, ...) alone would accept any module attribute such as
        # "nn" or "tensor"; require that the name resolves to an actual dtype.
        # An explicit raise (not ``assert``) survives ``python -O``.
        if not isinstance(getattr(torch, self.precision, None), torch.dtype):
            raise ValueError(f"Precision {self.precision} not supported by torch")
