import contextlib

import torch
from deepspeed.accelerator import get_accelerator
from deepspeed.profiling.flops_profiler.profiler import get_model_profile

from models.blip2 import Blip2, Blip2LanguageTransformer, Blip2VisionTransformer
from models.clip import CLIP
from models.mmfusion import MMFusion
from models.mmsiamese import CoMM, MMSiamese
from models.transformer import LanguageEncoder
from models.vit import VisionTransformer
"""
    This script computes the number of parameters, FLOPs (floating-point operations for a given model),
    MACs of the following models: CoMM, CLIP, BLIP-2 
"""

def _device_context(device_index):
    """Return a context manager that selects accelerator `device_index`.

    Wraps DeepSpeed's `get_accelerator().device(...)`. Some accelerator
    backends return None there instead of a context manager
    (NOTE(review): e.g. DeepSpeed's CPU accelerator -- confirm for the
    installed version). Using None in a `with` statement raises, so in that
    case fall back to a no-op context.
    """
    ctx = get_accelerator().device(device_index)
    return contextlib.nullcontext() if ctx is None else ctx


def _build_fusion(vision, language):
    """Build the CoMM MMFusion model on top of the given vision/language backbones.

    Both CoMM variants below (CLIP and BLIP-2 backbones) are profiled with
    this same fusion head, so its hyper-parameters are defined only here.
    """
    return MMFusion([vision, language], [None, None],
                    embed_dim=768, fusion="concat", pool="cls",
                    n_heads=8, n_layers=1, dropout=0, drop_prob=0)


def _profile(model, output_file, **model_inputs):
    """Run DeepSpeed's FLOPs profiler on `model` and write the summary to `output_file`.

    `model_inputs` is forwarded to `get_model_profile`: pass `args=[...]` for
    positional forward inputs or `kwargs={...}` for keyword forward inputs.
    """
    # NOTE(review): deliberately not wrapped in torch.no_grad(). PyTorch's
    # nn.MultiheadAttention / nn.TransformerEncoderLayer may switch to fused
    # fast-path kernels when autograd is disabled, and the profiler may not
    # count those kernels. Verify that FLOP counts are unchanged before adding it.
    get_model_profile(model, detailed=False, output_file=output_file, **model_inputs)


with _device_context(0):
    # NOTE(review): this context only selects the current device. `images` and
    # the models below are created without an explicit device and are not
    # moved with .to(...) here, so they stay on torch's default device.
    batch_size = 16
    images = torch.ones((batch_size, 3, 224, 224))
    # Each caption is the word "this " repeated 77 times. CLIP is limited to
    # max seq. length = 77.
    # NOTE(review): with start/end tokens this likely exceeds 77 tokens, which
    # relies on the tokenizer truncating -- confirm.
    text = [77 * "this " for _ in range(batch_size)]

    # CoMM w/ CLIP backbones profiling using MMFusion model
    vision = VisionTransformer("vit_base_patch32_clip_224.openai",
                               pretrained=True, output_value="token_embeddings",
                               freeze=True)
    language = LanguageEncoder("clip-ViT-B-32-multilingual-v1",
                               output_value="token_embeddings", use_dataset_cache=False,
                               freeze=True)
    fusion = _build_fusion(vision, language)
    _profile(fusion, "CoMM-CLIP.log", args=[[images, text]])

    # CLIP profiling (reuses the CLIP backbones built above)
    clip = CLIP(vision, language, optim_kwargs={"lr": 1e-4, "weight_decay": 1e-2})
    _profile(clip, "CLIP.log", kwargs=dict(image=images, text=text))

    # Drop references to the CLIP-based models before loading BLIP-2 weights,
    # so their memory can be reclaimed and peak memory stays lower.
    del fusion, clip, vision, language

    # BLIP2 profiling
    blip2 = Blip2()
    _profile(blip2, "BLIP2.log", kwargs=dict(image=images, text=text))
    # Likewise release the full BLIP-2 model before building its backbones below.
    del blip2

    # CoMM w/ BLIP2 backbones profiling using MMFusion model
    vision = Blip2VisionTransformer()
    language = Blip2LanguageTransformer()
    fusion = _build_fusion(vision, language)
    _profile(fusion, "CoMM-BLIP2.log", args=[[images, text]])





