

from tabicl.config.config_pretrain import ConfigPretrain


def calculate_foundation_forward_flops(cfg: ConfigPretrain):
    """Estimate the foundation model's FLOPs over a whole pretraining run.

    A per-sample cost is assembled from the X embedding, the attention
    projections, the attention matrix products, the per-layer linear blocks
    and the two final layers. It is then scaled by the number of optimizer
    steps and by the effective batch size (per-device batch size times
    gradient accumulation steps times number of devices).

    Preprocessing/quantile work and the embedding of Y are not counted.

    Args:
        cfg: Pretraining configuration. Reads ``cfg.optim`` (``steps``,
            ``batch_size``, ``gradient_accumulation_steps``), ``cfg.devices``,
            ``cfg.data`` (support-size and feature bounds, ``max_classes``,
            ``n_samples_query``) and ``cfg.model`` (``'dim_model'``,
            ``'n_layers'``).

    Returns:
        ``2 * per_sample_cost * steps * effective_batch_size``.
        NOTE(review): the leading factor 2 presumably counts a multiply-add
        as two operations -- confirm.
    """
    optim = cfg.optim
    n_updates = optim.steps
    samples_per_update = (
        optim.batch_size * optim.gradient_accumulation_steps * len(cfg.devices)
    )

    data = cfg.data
    support_lo, support_hi = data.min_samples_support, data.max_samples_support
    # Only the upper feature bound enters the estimate below.
    _feature_lo, feature_hi = data.min_features, data.max_features
    class_count = data.max_classes
    n_query = data.n_samples_query

    width = cfg.model['dim_model']
    depth = cfg.model['n_layers']

    # Midpoint of the support-size range, rounded down.
    support_mean = (support_lo + support_hi) // 2
    seq_mean = support_mean + n_query
    seq_max = support_hi + n_query

    per_sample_terms = (
        # X embedding. The sequence is assumed to reach the maximum support
        # size, even though it could be shorter.
        2 * seq_max * feature_hi * width,
        # Q, K, V and output projections. Attention runs on variable-length
        # sequences, so the average length is used.
        depth * 4 * seq_mean * width * width,
        # Support and query are passed through separately; each performs the
        # QK^T product and the product with V. Support pass:
        depth * 2 * support_mean * support_mean * width,
        # Query pass:
        depth * 2 * support_mean * n_query * width,
        # Linear blocks of every layer.
        depth * 2 * seq_mean * width * width * 4,
        # The final layers do not use the variable-length separation, so the
        # maximum length applies.
        seq_max * width * 4 * width,
        seq_max * 4 * width * class_count,
    )
    per_sample_cost = sum(per_sample_terms)

    return 2 * per_sample_cost * n_updates * samples_per_update

    



