


from tabicl.config.config_pretrain import ConfigPretrain
from tabicl.core.enums import ModelName
from tabicl.models.foundation.foundation_flops_calculator import calculate_foundation_forward_flops
from tabicl.models.foundation_flash.foundation_flash_flops_calculator import calculate_foundation_forward_flops
from tabicl.models.tab2d.tab2d_flops_calculator import calculate_tab2d_forward_flops


class FLOPSNotImplementedError(Exception):
    """Raised when FLOPs calculation is requested for an unsupported model."""


def calculate_flops(cfg: ConfigPretrain) -> float:

    match cfg.model_name:
        case ModelName.FOUNDATION:
            return calculate_foundation_forward_flops(cfg)
        case ModelName.FOUNDATION_FLASH:
            return calculate_foundation_forward_flops(cfg)
        case ModelName.TAB2D:
            return calculate_tab2d_forward_flops(cfg)
        case _:
            raise FLOPSNotImplementedError(f"Model {cfg.model_name} not supported for FLOPs calculation")