import torch.nn as nn
from typing import Optional, Union
from transformers import PreTrainedModel
import logging


class ModelSummary:
    """
    A static utility class to summarize PyTorch or HuggingFace models.

    Provides:
        - Total parameters
        - Trainable and frozen parameters
        - Percent breakdown
        - Estimated parameter size if stored as FP32, FP16, or BF16
        - Number of modules (reported under the "Number of layers" key)
        - Optional architecture printout

    Example:
    --------
    >>> from transformers import BertModel
    >>> model = BertModel.from_pretrained("bert-base-uncased")

    >>> # Either verbose
    >>> ModelSummary.summarize(model, verbose=True)

    >>> # Or with custom logger
    >>> import logging
    >>> logger = logging.getLogger("MyLogger")
    >>> logger.setLevel(logging.INFO)
    >>> logger.addHandler(logging.StreamHandler())
    >>> # Prevent log messages from being passed to the root logger
    >>> logger.propagate = False
    >>> ModelSummary.summarize(model, logger=logger, verbose=False)
    """

    # Bytes needed to store one parameter at each precision. FP16 and BF16
    # are both 16-bit formats, so they share the same storage cost.
    _BYTES_PER_PARAM_FP32 = 4
    _BYTES_PER_PARAM_HALF = 2
    _BYTES_PER_MB = 1024 ** 2

    @staticmethod
    def summarize(
        # Quoted so the annotation is not evaluated when the class is defined.
        model: "Union[nn.Module, PreTrainedModel]",
        model_name: Optional[str] = None,
        logger: Optional[logging.Logger] = None,
        verbose: bool = False,
        print_architecture: bool = False
    ) -> dict:
        """
        Compute parameter statistics for ``model`` and optionally report them.

        Args:
            model: The model to inspect.
            model_name: Display name; defaults to the model's class name.
            logger: If given, the report is emitted through ``logger.info``.
            verbose: If True and no logger is given, the report is printed.
            print_architecture: If True, the model's string representation is
                emitted as well -- through ``logger`` when one is given,
                otherwise printed to stdout regardless of ``verbose``.

        Returns:
            A dict with the keys "Model", "Total parameters",
            "Trainable parameters", "Frozen parameters", "Number of layers",
            "Estimated param size (MB, FP32)", "Estimated param size (MB, FP16)",
            "Estimated param size (MB, BF16)", "Trainable %" and "Frozen %".
            "Number of layers" counts every entry yielded by
            ``model.modules()``, which includes the model itself and any
            container modules. The size estimates are derived from the
            parameter count alone, so they do not depend on the dtype the
            model is currently stored in.
        """
        model_name = model_name or model.__class__.__name__

        # Single pass over the parameters instead of one traversal per metric.
        total_params = 0
        trainable_params = 0
        for param in model.parameters():
            count = param.numel()
            total_params += count
            if param.requires_grad:
                trainable_params += count
        frozen_params = total_params - trainable_params
        num_layers = sum(1 for _ in model.modules())

        # Guard against parameter-free models to avoid dividing by zero.
        if total_params > 0:
            trainable_pct = 100.0 * trainable_params / total_params
            frozen_pct = 100.0 * frozen_params / total_params
        else:
            trainable_pct = 0.0
            frozen_pct = 0.0

        # Use fixed bytes-per-parameter rather than the current element size:
        # otherwise a model already held in half precision would report its
        # FP16 footprint under the FP32 label (and half of that as FP16).
        model_size_fp32_mb = (
            total_params * ModelSummary._BYTES_PER_PARAM_FP32 / ModelSummary._BYTES_PER_MB
        )
        model_size_fp16_mb = (
            total_params * ModelSummary._BYTES_PER_PARAM_HALF / ModelSummary._BYTES_PER_MB
        )
        model_size_bf16_mb = model_size_fp16_mb

        summary = {
            "Model": model_name,
            "Total parameters": total_params,
            "Trainable parameters": trainable_params,
            "Frozen parameters": frozen_params,
            "Number of layers": num_layers,
            "Estimated param size (MB, FP32)": model_size_fp32_mb,
            "Estimated param size (MB, FP16)": model_size_fp16_mb,
            "Estimated param size (MB, BF16)": model_size_bf16_mb,
            "Trainable %": trainable_pct,
            "Frozen %": frozen_pct,
        }

        if logger is not None or verbose:
            ModelSummary._log_summary(summary, logger=logger, verbose=verbose)

        if print_architecture:
            # The caller asked for the architecture explicitly, so always emit
            # it: force the stdout fallback on when no logger is available.
            ModelSummary._output(
                f"\n🧱 Model Architecture for {model_name}:\n{model}", logger, True
            )

        return summary

    @staticmethod
    def _log_summary(summary: dict, logger: Optional[logging.Logger], verbose: bool):
        """
        Emit a human-readable rendering of ``summary``, one line per call.

        Args:
            summary: A dict as returned by ``summarize``.
            logger: Destination logger, or None to fall back to stdout.
            verbose: Whether to print when no logger is given.
        """
        lines = [
            "📊 Model Summary:",
            f"  Model: {summary['Model']}",
            f"  Total parameters: {summary['Total parameters']:,}",
            f"  Trainable parameters: {summary['Trainable parameters']:,} ({summary['Trainable %']:.2f}%)",
            f"  Frozen parameters: {summary['Frozen parameters']:,} ({summary['Frozen %']:.2f}%)",
            f"  Number of layers: {summary['Number of layers']:,}",
            "💾 Estimated Model Parameter Size:",
            f"  FP32: {summary['Estimated param size (MB, FP32)']:,.2f} MB",
            f"  FP16: {summary['Estimated param size (MB, FP16)']:,.2f} MB",
            f"  BF16: {summary['Estimated param size (MB, BF16)']:,.2f} MB",
        ]
        for line in lines:
            ModelSummary._output(line, logger, verbose)

    @staticmethod
    def _output(msg: str, logger: Optional[logging.Logger], verbose: bool):
        """
        Send ``msg`` to ``logger`` at INFO level, or print it if ``verbose``.

        The logger takes precedence: when one is given, nothing is printed.
        With no logger and ``verbose`` False, the message is dropped.
        """
        if logger is not None:
            logger.info(msg)
        elif verbose:
            print(msg)
