import os
import math
import time
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

import timm
from timm.models import create_model, VisionTransformer

class DummyInferenceParams:
    """Bare stand-in for an inference-parameters object.

    Each instance starts with exactly two attributes:

    * ``key_value_memory_dict`` -- a fresh, empty dict.
    * ``seqlen_offset`` -- the integer ``0``.

    NOTE(review): presumably mirrors the attributes the model's forward pass
    reads from a real inference-params object -- confirm against the model.
    """

    def __init__(self) -> None:
        # A new dict per instance, so state is never shared between objects.
        self.key_value_memory_dict: dict = dict()
        self.seqlen_offset: int = 0

if __name__ == "__main__":
    # Benchmark script: for every image resolution, build a model on the GPU,
    # run `repeats` single-image forward passes (twice), and report the
    # per-forward latency, the loop wall time and the peak CUDA memory.
    from matmamba.matmamba2_vision import MatMamba2Vision

    inference_params = None
    # model_str = "vit"
    model_str = "matmamba"

    # The per-layer dims list is indexed by `layer.layer_idx`, so its length
    # must match the number of layers handed to the model below.
    n_layer = 20
    mixnmatch_dim = 64
    mixnmatch_dims = [mixnmatch_dim] * n_layer

    repeats = 10
    y_wall_times = []
    y_memory_usages = []
    x_image_res = []
    image_res_list = [
        224, 256, 384, 512, 768, 1024, 1280, 1536,
        1792, 2048, 2304, 2560, 2816, 3072, 3328, 3584,
    ]

    for image_res in image_res_list:
        if model_str == "vit":
            # DEIT model. (A plain `VisionTransformer(img_size=image_res,
            # patch_size=16)` was previously tried here as an alternative.)
            model = timm.create_model(
                'deit_base_distilled_patch16_384.fb_in1k',
                pretrained=False,
                img_size=image_res,
                num_classes=1000,
            ).cuda()
            print(model.patch_embed.num_patches)
            print(model.patch_embed.img_size)
        else:
            model = MatMamba2Vision(
                d_model=512,
                n_layer=n_layer,
                d_intermediate=0,
                n_classes=1000,
                patch_size=16,
                image_size=image_res,
                drop_path_rate=0,
                proj_drop_rate=0,
            ).cuda()

            # Fresh params object per model; it is passed as the second
            # positional argument of the forward call below.
            inference_params = DummyInferenceParams()

            # NOTE(review): presumably selects the nested sub-model width
            # used by each mixer -- confirm against the MatMamba mixer code.
            for layer in model.layers:
                layer.mixer.mixnmatch = True
                layer.mixer.mixnmatch_dims = mixnmatch_dims[layer.layer_idx]

        model.eval()
        with torch.no_grad():
            # Two identical passes; only the second one (idx == 1) is
            # recorded, so the first presumably serves as a warm-up.
            for idx in range(2):
                forward_times = []
                memory_usages = []
                # perf_counter() is monotonic and high-resolution, unlike
                # time.time(), which can jump if the system clock changes.
                wall_start = time.perf_counter()
                for _ in range(repeats):
                    # Reset memory stats before the forward pass.
                    torch.cuda.reset_peak_memory_stats()
                    # Create the input directly on the GPU so host-side RNG
                    # and the host-to-device copy do not inflate wall_time.
                    image = torch.randn(
                        1, 3, image_res, image_res, device="cuda"
                    )
                    # Drain pending GPU work so the timer brackets only the
                    # forward pass.
                    torch.cuda.synchronize()
                    forward_start = time.perf_counter()
                    if inference_params is not None:
                        ret = model(image, inference_params)
                    else:
                        ret = model(image)
                    torch.cuda.synchronize()
                    forward_times.append(time.perf_counter() - forward_start)
                    # Peak bytes allocated since the reset above.
                    memory_usages.append(torch.cuda.max_memory_allocated())
                wall_time = time.perf_counter() - wall_start

                layer_time = sum(forward_times) / len(forward_times)
                avg_memory_used = sum(memory_usages) / len(memory_usages)
                if idx == 1:
                    y_wall_times.append(wall_time / repeats)
                    y_memory_usages.append(avg_memory_used)
                    x_image_res.append(image_res)
                print(
                    f"image_res={image_res}, ret.shape={ret.shape}, "
                    f"layer_time={layer_time}, wall_time={wall_time}, "
                    f"memory_used={avg_memory_used / (1024 ** 2):.2f} MB"
                )

    print("x_image_res =", x_image_res)
    print(f"y_wall_times_{mixnmatch_dim} =", y_wall_times)
    print(f"y_memory_usages_{mixnmatch_dim} =", y_memory_usages)
    

