import argparse
import torch

from torchinfo import summary
from transformers import AutoModel
from typing import Tuple


def get_model_info(
    model_name: str,
    input_size: Tuple[int, ...],
    dtype: torch.dtype = torch.int,
    depth: int = 3,
):
    """Load a pretrained model and print a torchinfo summary of it.

    Args:
        model_name (str): Model identifier or local path, passed to
            ``AutoModel.from_pretrained``.
        input_size (Tuple[int, ...]): Shape of the single dummy input tensor
            that torchinfo builds and feeds to the model's forward pass,
            e.g. ``(1, 512)``. Presumably (batch size, sequence length) for a
            text encoder such as the default BERT model — confirm per model.
        dtype (torch.dtype, optional): dtype of that dummy input. Defaults to
            ``torch.int`` so the input is integer-valued, as token-id inputs
            to embedding layers must be.
        depth (int, optional): How many levels of nested modules to show in
            the summary. Defaults to 3, torchinfo's own default.

    Returns:
        torchinfo.ModelStatistics: The statistics object returned by
        ``summary`` (parameter counts, mult-adds, ...). The same table is
        also printed by ``summary`` itself.
    """
    model = AutoModel.from_pretrained(pretrained_model_name_or_path=model_name)

    # torchinfo prints the table at its default verbosity; also hand the
    # statistics back so callers can use them programmatically.
    return summary(
        model=model,
        input_size=input_size,
        dtypes=[dtype],
        depth=depth,
    )


def parse_input_size(text: str) -> Tuple[int, ...]:
    """Parse a comma-separated shape such as ``"1,512"`` into ``(1, 512)``.

    Used as an argparse ``type=`` callable, so bad input is reported as a
    normal usage error instead of an uncaught traceback.

    Args:
        text (str): Comma-separated integers; whitespace around each item
            is allowed (``"2, 128"``).

    Returns:
        Tuple[int, ...]: The parsed dimensions.

    Raises:
        argparse.ArgumentTypeError: If any item is not an integer (including
            empty items such as a trailing comma) or is not positive.
    """
    try:
        dims = tuple(int(part) for part in text.split(','))
    except ValueError as err:
        raise argparse.ArgumentTypeError(
            f'invalid input size {text!r}: expected comma-separated '
            f'integers, e.g. "1,512"'
        ) from err
    if any(dim <= 0 for dim in dims):
        raise argparse.ArgumentTypeError(
            f'invalid input size {text!r}: every dimension must be positive'
        )
    return dims


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '-mn',
        '--model_name',
        default='google-bert/bert-base-uncased',
        type=str,
        help='The model name.',
    )
    parser.add_argument(
        '-is',
        '--input_size',
        # argparse also runs string defaults through `type`, so
        # args.input_size is always a tuple of ints.
        default='1,512',
        type=parse_input_size,
        help='The input size of the given model, as comma-separated '
             'integers (e.g. "1,512").',
    )

    args = parser.parse_args()

    get_model_info(
        model_name=args.model_name,
        input_size=args.input_size,
    )
