import humanfriendly
import numpy as np
import torch


def get_human_readable_count(number: int) -> str:
    """Return human_readable_count

    Originated from:
    https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py

    Abbreviates a non-negative number with K, M, B, T for thousands, millions,
    billions and trillions, respectively, always printing two decimals.
    Numbers beyond trillions are still expressed in T.
    Examples:
        >>> get_human_readable_count(123)
        '123.00  '
        >>> get_human_readable_count(1234)  # (one thousand)
        '1.23 K'
        >>> get_human_readable_count(2e6)   # (two million)
        '2.00 M'
        >>> get_human_readable_count(3e9)   # (three billion)
        '3.00 B'
        >>> get_human_readable_count(4e12)  # (four trillion)
        '4.00 T'
        >>> get_human_readable_count(5e15)  # (more than trillion)
        '5000.00 T'
        >>> get_human_readable_count(999_999)  # rounds up into the next unit
        '1.00 M'
    Args:
        number: a non-negative number (int or float)
    Return:
        A string formatted according to the pattern described above.
    """
    assert number >= 0
    labels = [" ", "K", "M", "B", "T"]
    num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
    num_groups = int(np.ceil(num_digits / 3))
    num_groups = min(num_groups, len(labels))  # don't abbreviate beyond trillions
    index = num_groups - 1
    scaled = number * (10 ** (-3 * index))
    # Two-decimal formatting can round a value up to 1000 (e.g. 999_999 ->
    # 999.999 -> "1000.00 K"); promote it to the next unit when one exists.
    if round(scaled, 2) >= 1000 and index < len(labels) - 1:
        scaled /= 1000
        index += 1
    return f"{scaled:.2f} {labels[index]}"


def to_bytes(dtype) -> int:
    """Return the number of bytes occupied by a single element of ``dtype``.

    e.g. torch.float16 -> 2, torch.float32 -> 4, torch.int8 -> 1, torch.bool -> 1

    The size is taken from the dtype itself (or from torch) rather than parsed
    from the dtype's name. Name parsing (last two characters // 8) fails for
    ``torch.int8``, ``torch.uint8`` and ``torch.bool``, and gives 3 for
    ``torch.complex128``.

    Args:
        dtype: a ``torch.dtype`` (anything exposing ``itemsize``, such as a
            numpy dtype, also works).
    Return:
        The element size in bytes.
    """
    # Newer torch releases (and numpy dtypes) expose ``itemsize`` directly.
    itemsize = getattr(dtype, "itemsize", None)
    if itemsize is None and isinstance(dtype, torch.dtype):
        itemsize = torch.empty((), dtype=dtype).element_size()
    if itemsize is None:
        # Legacy fallback for other inputs: parse the trailing bit width,
        # e.g. "float16" -> 2.
        return int(str(dtype)[-2:]) // 8
    return int(itemsize)


def _get_lm_head(model: torch.nn.Module):
    """Return ``model.lm.model.lm_head`` if that attribute path exists and is a Module, else None."""
    module = model
    for attr in ("lm", "model", "lm_head"):
        module = getattr(module, attr, None)
        if module is None:
            return None
    return module if isinstance(module, torch.nn.Module) else None


def _unique_named_parameters(model: torch.nn.Module):
    """Return (name, parameter) pairs of ``model`` plus those of ``model.lm.model.lm_head``.

    The lm_head parameters are included explicitly, in case they are not reachable
    through ``model.parameters()``. Each tensor appears only once, so a head that is
    also registered as a submodule, or whose weight is tied to another parameter,
    is not double-counted. Parameters reached only through the head get the
    ``lm.model.lm_head.`` name prefix.
    """
    sources = [("", model)]
    lm_head = _get_lm_head(model)
    if lm_head is not None:
        sources.append(("lm.model.lm_head.", lm_head))

    seen = set()
    named = []
    for prefix, module in sources:
        for name, param in module.named_parameters():
            if id(param) in seen:
                continue
            seen.add(id(param))
            named.append((prefix + name, param))
    return named


def model_summary(model: torch.nn.Module) -> str:
    """Build a human-readable summary of ``model``.

    The summary contains the module structure (``str(model)``), the name and
    element count of every trainable parameter, the total and trainable parameter
    counts, the byte size of the trainable parameters, and the dtype of the first
    parameter. Parameters under ``model.lm.model.lm_head`` are included when that
    attribute exists, and are counted once each.

    Args:
        model: the module to summarize.
    Return:
        The multi-line summary string.
    """
    named_params = _unique_named_parameters(model)
    trainable = [(name, p) for name, p in named_params if p.requires_grad]

    tot_count = sum(p.numel() for _, p in named_params)
    train_count = sum(p.numel() for _, p in trainable)
    # Avoid ZeroDivisionError for a parameter-less model.
    ratio = train_count * 100.0 / tot_count if tot_count > 0 else 0.0
    percent_trainable = "{:.1f}".format(ratio)
    tot_params = get_human_readable_count(tot_count)
    num_params = get_human_readable_count(train_count)

    parts = ["Model structure:\n", str(model), "\n\nAnalysis model:\n"]
    for name, param in trainable:
        parts.append(f"Name: {name}, Total Size: {param.numel()} \n")

    parts.append("\n\nModel summary:\n")
    parts.append(f"    Class Name: {model.__class__.__name__}\n")
    parts.append(f"    Total Number of model parameters: {tot_params}\n")
    parts.append(
        f"    Number of trainable parameters: {num_params} ({percent_trainable}%)\n"
    )
    num_bytes = humanfriendly.format_size(
        sum(p.numel() * to_bytes(p.dtype) for _, p in trainable)
    )
    parts.append(f"    Size: {num_bytes}\n")
    # A parameter-less model has no dtype to report.
    dtype = named_params[0][1].dtype if named_params else None
    parts.append(f"    Type: {dtype}")
    return "".join(parts)
