import sys
from matmamba.mamba2 import MatMamba2
from matmamba.mixer_seq_simple import MatMambaLMHeadModel

from mamba_ssm.models.config_mamba import MambaConfig

# (n_layer, d_model) for each supported model size.
MODEL_SIZES = {
    "130m": (24, 768),
    "370m": (48, 1024),
    "790m": (48, 1536),
    "1.4b": (48, 2048),
    "2.8b": (64, 2560),
}
DEFAULT_MODEL_SIZE = "130m"


def parse_model_size(argv):
    """Return the model-size key selected on the command line.

    Args:
        argv: Argument vector in ``sys.argv`` form (program name first).

    Returns:
        ``argv[1]`` when one argument is given, otherwise ``DEFAULT_MODEL_SIZE``.

    Raises:
        SystemExit: If more than one argument is passed, or the requested
            size is not a key of ``MODEL_SIZES``.
    """
    if len(argv) > 2:
        # Reject extra arguments instead of silently falling back to the default.
        raise SystemExit(f"usage: {argv[0]} [{'|'.join(MODEL_SIZES)}]")
    model_str = argv[1] if len(argv) == 2 else DEFAULT_MODEL_SIZE
    if model_str not in MODEL_SIZES:
        raise SystemExit(
            f"unknown model size {model_str!r}; choose one of: {', '.join(MODEL_SIZES)}"
        )
    return model_str


def build_config(model_str):
    """Return the ``MambaConfig`` for a key of ``MODEL_SIZES``."""
    n_layer, d_model = MODEL_SIZES[model_str]
    return MambaConfig(n_layer=n_layer, d_model=d_model)


def count_params(params):
    """Return ``(number of tensors, total element count)`` for an iterable of parameters."""
    params = list(params)
    return len(params), sum(p.numel() for p in params)


def param_stats(model):
    """Summarize a model's parameters as ``(tensor count, element count)`` pairs.

    Expects ``model.backbone.embedding`` to be the embedding module.

    Returns:
        A dict with keys:
          - ``"decay"``: trainable parameters with ``dim() >= 2``.
          - ``"nodecay"``: trainable parameters with ``dim() < 2``.
          - ``"embedding"``: every parameter of ``model.backbone.embedding``.
          - ``"non_embedding"``: every other parameter of ``model``.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    embedding = list(model.backbone.embedding.parameters())
    # Match by identity rather than by name, so a non-embedding parameter whose
    # name merely contains "embedding" is not misclassified.
    embedding_ids = {id(p) for p in embedding}
    return {
        "decay": count_params(p for p in trainable if p.dim() >= 2),
        "nodecay": count_params(p for p in trainable if p.dim() < 2),
        "embedding": count_params(embedding),
        "non_embedding": count_params(
            p for p in model.parameters() if id(p) not in embedding_ids
        ),
    }


def main(argv=None):
    """Build the requested MatMamba LM and print its parameter breakdown.

    Args:
        argv: Argument vector in ``sys.argv`` form; defaults to ``sys.argv``.
    """
    argv = sys.argv if argv is None else argv
    model = MatMambaLMHeadModel(build_config(parse_model_size(argv)))
    print(model)

    stats = param_stats(model)

    num_tensors, num_elems = stats["decay"]
    print(f"num decayed parameter tensors: {num_tensors}, with {num_elems:,} parameters")
    num_tensors, num_elems = stats["nodecay"]
    print(f"num non-decayed parameter tensors: {num_tensors}, with {num_elems:,} parameters")

    num_tensors, num_embedding_params = stats["embedding"]
    print(f"num embedding parameter tensors: {num_tensors}, with {num_embedding_params:,} parameters")

    num_tensors, num_non_embedding_params = stats["non_embedding"]
    print(f"num non-embedding parameter tensors: {num_tensors}, with {num_non_embedding_params:,} parameters")

    print(f"Sum of non-embedding and embedding parameters: {num_non_embedding_params + num_embedding_params:,}")


if __name__ == "__main__":
    main()