# This script is expected to live in the root of the DeiT repository.
import torch
import time
from timm.models import create_model
#import prettytable
import numpy as np
from prettytable import PrettyTable
# Results table shared by the whole script: one row per (model, fixation count)
# benchmark run, holding the measured throughput in images per second.
Table = PrettyTable(["Model", "Fixations", "Images/Sec"])
#import Original_models
#import V3_throughput_models
'''import NEW_models_V3'''
#import Original_models
#import throughput_models
import Step_1_Flops_patchconvnet_models

# Batch size shared by every active entry of _MODEL_PARAMS below. The values in
# the trailing comment are presumably alternatives used in earlier runs.
bs = 256 #1024 #2048 #4096 #8704
#bs = 8704

# Benchmark configuration: maps a model name (the string handed to timm's
# create_model in compute_throughput) to a (batch_size, input_resolution)
# tuple. Only uncommented entries are benchmarked; the commented-out entries
# are kept as a record of other configurations.
_MODEL_PARAMS = {
    #"tf_efficientnet_b7": (70, 600),
    #"tf_efficientnet_b6": (80, 528),
    #"tf_efficientnet_b5": (160, 456),
    #"tf_efficientnet_b4": (240, 380),
    #"tf_efficientnet_b3": (350, 300),
    #"tf_efficientnet_b2": (700, 260),
    #"tf_efficientnet_b1": (900, 240),
    #"tf_efficientnet_b0": (1000, 224),

    #"deit_tiny_patch16_224": (3800, 224),
    #"deit_small_patch16_224": (3100, 224),
    #"deit_base_patch16_224": (1600, 224)

    #"vit_base_patch16_384": (356, 384),
    #"vit_large_patch16_384": (256, 384),

    #"deit_tiny_patch16_224": (512, 224),
    #"deit_foveated_small_patch16_224": (128, 224),
    #"deit_foveated_tiny_patch16_224": (1, 224),
    #"deit_small_patch16_224": (8192, 224),
    #"deit_foveated_small_patch16_224": (8192, 224),



    #"deit_small_patch16_224": (bs, 224),
    #"deit_foveated_small_patch16_224": (bs, 224),

    # Currently active entry: square 224-pixel inputs at the shared batch size.
    "S60": (bs, 224),

    #"deit_small_patch16_224_FS_1x1_3x3": (bs, 224),
    #"deit_small_patch16_224_FS_3x3_3x3": (bs, 224),
    #"deit_small_patch16_224_FS_1x1_5x5": (bs, 224),
    #"deit_small_patch16_224_FS_5x5_5x5": (bs, 224),
}


@torch.no_grad()
def compute_throughput(model_name, num_fix, *, warmup_iters=3, num_iters=30,
                       flag_partial=False):
    """Measure the inference throughput of a model on the CUDA device.

    The model is built with ``create_model(model_name, pretrained=False)``,
    put in eval mode and fed a fixed random batch whose batch size and
    resolution come from ``_MODEL_PARAMS[model_name]``. After a few untimed
    warm-up passes, ``num_iters`` forward passes are timed individually.

    Args:
        model_name: Key of ``_MODEL_PARAMS`` and name passed to
            ``create_model``.
        num_fix: Number of fixations. When greater than 0 it is stored on the
            model as ``model.num_fixations``; otherwise the attribute is left
            untouched.
        warmup_iters: Number of untimed forward passes run before timing.
        num_iters: Number of timed forward passes (must be at least 1).
        flag_partial: When true, a second model
            (``deit_foveated_small_patch16_224_PartialFS``) is created and,
            for ``num_fix != 0``, applied to the second output of the main
            model in every pass. The flag is also stored on the main model
            as ``model.flag_partial``.

    Returns:
        A 0-dim float32 tensor: batch size divided by the mean time of one
        timed pass, i.e. images per second.

    Raises:
        ValueError: If ``num_iters < 1`` or ``warmup_iters < 0``.
        KeyError: If ``model_name`` is not a key of ``_MODEL_PARAMS``.
    """
    # Validate before touching the GPU so that bad arguments fail fast; an
    # empty timing list would otherwise yield a NaN throughput.
    if num_iters < 1:
        raise ValueError(f"num_iters must be >= 1, got {num_iters}")
    if warmup_iters < 0:
        raise ValueError(f"warmup_iters must be >= 0, got {warmup_iters}")

    torch.cuda.empty_cache()
    device = torch.device('cuda')

    model = create_model(model_name, pretrained=False)
    model.save_images = False
    model.eval()
    model.to(device)
    model.flag_partial = flag_partial

    model_FS = None
    if flag_partial:
        model_FS = create_model('deit_foveated_small_patch16_224_PartialFS', pretrained=False)
        #model_FS = create_model('deit_small_patch16_224_FS', pretrained=False)
        model_FS.eval()
        model_FS.to(device)

    batch_size, resolution = _MODEL_PARAMS[model_name]

    if num_fix > 0:
        model.num_fixations = num_fix

    inputs = torch.randn(batch_size, 3, resolution, resolution, device=device)

    def run_forward():
        # One benchmark pass. In partial mode with fixations enabled the main
        # model returns a pair whose second element is fed to model_FS
        # (presumably intermediate features -- confirm against the model).
        if flag_partial and num_fix != 0:
            _, feat = model(inputs)
            model_FS(feat)
        else:
            model(inputs)

    # Warm-up passes are not timed.
    for _ in range(warmup_iters):
        run_forward()

    # Wait for the warm-up work to finish so it is not billed to the first
    # timed iteration.
    torch.cuda.synchronize()

    timing = []
    for _ in range(num_iters):
        # perf_counter is monotonic, unlike time.time(), which can jump when
        # the system clock is adjusted.
        start = time.perf_counter()
        run_forward()
        # Block until the GPU has finished before stopping the clock.
        torch.cuda.synchronize()
        timing.append(time.perf_counter() - start)

    timing = torch.as_tensor(timing, dtype=torch.float32)
    return batch_size / timing.mean()


if __name__ == "__main__":
    # Benchmark every configured model and accumulate the results in Table.
    for model_name in _MODEL_PARAMS:
        # Toggle: sweep fixation counts 1..5 for this model. Replace with a
        # condition such as model_name == "deit_foveated_small_patch16_224"
        # to run all other models only once with num_fix=0.
        sweep_fixations = True
        fixation_counts = np.arange(1, 6) if sweep_fixations else [0]

        for num_fix in fixation_counts:
            imgs_per_sec = compute_throughput(model_name, num_fix)
            Table.add_row([model_name, num_fix, imgs_per_sec.item()])

        # imgs_per_sec holds the measurement of the last run for this model.
        print(f"{model_name}: {imgs_per_sec:.2f}")
        # The table is cumulative, so it is re-printed after every model.
        print(Table)
