from pathlib import Path
import sys
import time
from torch import Tensor
from typing import List
import matplotlib.pyplot as plt
sys.path.append('..')

import torch
from modeling.mamba2.modeling_mamba2_dao import Mamba2ForCausalLM
from transformers import AutoTokenizer


def get_long_prompt():
    """Build a synthetic long prompt by repeating one short sentence 2000 times."""
    sentence = 'My name is John, and I like eating donuts. '
    repetitions = 2000
    prompt = sentence * repetitions
    return prompt


@torch.no_grad()
def get_stats(model: "Mamba2ForCausalLM", input_ids: Tensor, chunk_size: int = 128, n_layers: int = 48):
    """Return the total element count of the model's per-layer recurrent state.

    The function prints the length of ``input_ids`` and feeds only its first
    token through ``model`` (with no prior state). It then sums ``numel()`` of
    the first two entries of each of the first ``n_layers`` layer states.

    Args:
        model: Mamba2 causal LM. It is called as ``model(ids, states=None)`` and
            must return an object with a ``.states`` sequence. ``states[i][0]``
            and ``states[i][1]`` are treated as layer ``i``'s conv state and SSM
            state tensors.
        input_ids: Token ids. Assumed to be 1-D ``(seq_len,)`` (TODO confirm).
            Only its length (printed) and its first element are used.
        chunk_size: Currently unused; kept so existing callers keep working.
        n_layers: How many layer states to sum. If this exceeds the number of
            layers in ``model.states``, indexing will likely raise.

    Returns:
        Total number of elements (not bytes) across the summed conv and SSM
        state tensors.
    """
    print(f"Seq len: {len(input_ids)}")
    # Only one token is fed through. Presumably this is because the recurrent
    # state has a fixed size regardless of how many tokens were consumed,
    # so a single step is enough to materialise it. NOTE(review): confirm.
    first_token = input_ids[:1].unsqueeze(0)  # shape (1, 1) for 1-D input_ids
    layer_states = model(first_token, states=None).states

    state_size = 0
    for layer_i in range(n_layers):
        conv_state = layer_states[layer_i][0]
        ssm_state = layer_states[layer_i][1]
        state_size += conv_state.numel() + ssm_state.numel()
    return state_size


def main():
    """Load a Mamba2 checkpoint and print the first layer's mixer attributes.

    The checkpoint path is relative to the current working directory. The model
    is cast to float16 and moved to CUDA. The mixer attributes are printed in
    this order: nheads, headdim, d_ssm, d_inner, d_state, ngroups.
    """
    model_size = '130m'
    layers_per_size = {'130m': 24, '370m': 48, '780m': 48}
    n_layers = layers_per_size[model_size]  # not used further in this function
    chunk_size = 512  # not used further in this function
    device, dtype = 'cuda', torch.float16
    pretrained_path = f'../../ckpts/mamba/mamba2-{model_size}'

    print(f"Loading from {pretrained_path}")
    loaded = Mamba2ForCausalLM.from_pretrained(pretrained_path)
    model = loaded.to(dtype=dtype, device=device)

    first_mixer = model.backbone.layers[0].mixer
    for attr_name in ('nheads', 'headdim', 'd_ssm', 'd_inner', 'd_state', 'ngroups'):
        print(getattr(first_mixer, attr_name))


# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
