

from tabicl.config.config_pretrain import ConfigPretrain


def calculate_tab2d_forward_flops(cfg: ConfigPretrain):
    """Estimate the total forward-pass FLOPs of a Tab2D pretraining run.

    The estimate is analytical. It uses average table dimensions derived from
    the data config. It counts the multiply-accumulates of the dominant
    matrix products in a single forward pass and scales them by the number of
    optimisation steps and the effective global batch size.
    Preprocessing/quantile cost is ignored.

    Args:
        cfg: Pretraining configuration. Reads ``cfg.optim.steps``,
            ``cfg.optim.batch_size``, ``cfg.optim.gradient_accumulation_steps``,
            ``cfg.devices``, the support/feature ranges, ``max_classes`` and
            ``n_samples_query`` from ``cfg.data``, and ``cfg.model['dim']`` /
            ``cfg.model['n_layers']``.

    Returns:
        Estimated forward FLOPs summed over all steps and all tables in the
        effective batch (per-device batch x accumulation steps x devices).
    """
    optim = cfg.optim
    data = cfg.data

    # Tables processed per optimiser step across every device.
    global_batch = (
        optim.batch_size * optim.gradient_accumulation_steps * len(cfg.devices)
    )

    # Average table shape. The target (y) column is appended to the x
    # features as one more feature token, hence the +1.
    n_feat = (data.min_features + data.max_features) // 2 + 1
    n_support = (data.min_samples_support + data.max_samples_support) // 2
    n_query = data.n_samples_query
    n_rows = n_support + n_query

    d = cfg.model['dim']
    n_layers = cfg.model['n_layers']
    n_classes = data.max_classes

    # Per-component cost of one forward pass over one table (in insertion
    # order, which is also the summation order).
    component_cost = {
        # Every cell is embedded from a scalar (width 1) to width d.
        "embedding": n_rows * n_feat * 1 * d,
        # Attention across observations: Q/K/V/O projections on every cell,
        # support self-attention, and query rows attending to support rows.
        "obs_qkvo": n_layers * 4 * n_feat * n_rows * d * d,
        "obs_support": n_layers * 2 * n_support * n_support * n_feat * d,
        "obs_query": n_layers * 2 * n_support * n_query * n_feat * d,
        # Attention across features: projections plus two independent
        # self-attentions, one over support rows and one over query rows.
        "feat_qkvo": n_layers * 4 * n_rows * n_feat * d * d,
        "feat_support": n_layers * 2 * n_feat * n_feat * n_support * d,
        "feat_query": n_layers * 2 * n_feat * n_feat * n_query * d,
        # Four linear modules per layer, each costing d * 4d per cell.
        "linear": n_layers * 4 * n_rows * n_feat * d * (d * 4),
        # Output head, applied only to query rows.
        "head": n_query * n_feat * d * n_classes,
    }
    per_table = sum(component_cost.values())

    # Factor 2 presumably converts multiply-accumulates into FLOPs
    # (one multiply + one add); no backward-pass factor is applied.
    return 2 * per_table * optim.steps * global_batch

    



