import torch
import torch.nn as nn
from typing import Optional, Union
from transformers import PreTrainedModel

import logging
class ModelUtils:
    """Static helpers for inspecting the weights of ``torch.nn.Module`` objects."""

    @staticmethod
    def check_weights_match(model, state_dict, verbose=True):
        """
        Checks if the given state_dict matches the model's current weights exactly.
        Raises an error if the model and state_dict do not match structurally.

        Tensor entries are compared on the model tensor's device, so a state_dict
        loaded onto CPU can be checked against a model that lives elsewhere.
        Entries whose dtype or shape differ are reported as differing, since
        they cannot be an exact match. Non-tensor entries (e.g. extra state
        produced by ``Module.get_extra_state``) are compared with ``==``.

        Args:
            model (torch.nn.Module): The model to compare weights for.
            state_dict (dict): The state_dict to compare against.
            verbose (bool): If True, prints which parameters differ and summary.

        Returns:
            match (bool): True if all weights match exactly, False otherwise.

        Raises:
            ValueError: If the key sets of ``model.state_dict()`` and
                ``state_dict`` differ.
        """
        # Build the model's state_dict once and reuse it for keys and values.
        model_state = model.state_dict()
        model_keys = set(model_state.keys())
        state_dict_keys = set(state_dict.keys())

        # Structural check: keys must match
        if model_keys != state_dict_keys:
            missing = model_keys - state_dict_keys
            extra = state_dict_keys - model_keys
            raise ValueError(
                f"Model and state_dict do not match structurally.\n"
                f"Missing keys in state_dict: {missing}\n"
                f"Unexpected keys in state_dict: {extra}"
            )

        # Compare values
        match = True
        for name, param_current in model_state.items():
            if not ModelUtils._values_equal(param_current, state_dict[name]):
                match = False
                if verbose:
                    print(f"🔄 Parameter '{name}' differs between model and state_dict.")

        if verbose:
            if match:
                print("✅ Input state_dict matches the model's current weights exactly.")
            else:
                print("⚠️ Input state_dict is structurally compatible but has differing parameter values.")

        return match

    @staticmethod
    def _values_equal(current, loaded):
        """Return True if two state_dict entries are exactly equal.

        Args:
            current: Entry taken from the model's own state_dict.
            loaded: Entry from the state_dict being checked.
        """
        current_is_tensor = isinstance(current, torch.Tensor)
        loaded_is_tensor = isinstance(loaded, torch.Tensor)
        if current_is_tensor and loaded_is_tensor:
            # Differing dtype or shape can never be an exact match; checking
            # explicitly avoids depending on torch.equal's handling of them.
            if current.dtype != loaded.dtype or current.shape != loaded.shape:
                return False
            # torch.equal requires both operands on the same device; .to() is a
            # no-op (no copy) when the devices already agree.
            return torch.equal(current, loaded.to(current.device))
        if current_is_tensor or loaded_is_tensor:
            # A tensor and a non-tensor are never considered equal.
            return False
        # Non-tensor entries, e.g. extra state returned by Module.get_extra_state().
        return current == loaded



class ModelSummary:
    """
    A static utility class to summarize PyTorch or HuggingFace models.

    Provides:
        - Total parameters
        - Trainable and frozen parameters
        - Percent breakdown
        - Estimated model size for FP32, FP16, BF16
        - Number of modules (reported under the key "Number of layers")
        - Optional architecture printout

    Example:
    --------
    >>> from transformers import BertModel
    >>> model = BertModel.from_pretrained("bert-base-uncased")

    >>> # Either verbose
    >>> ModelSummary.summarize(model, verbose=True)

    >>> # Or with custom logger
    >>> import logging
    >>> logger = logging.getLogger("MyLogger")
    >>> logger.setLevel(logging.INFO)
    >>> logger.addHandler(logging.StreamHandler())
    >>> # Prevent log messages from being passed to the root logger
    >>> logger.propagate = False
    >>> ModelSummary.summarize(model, logger=logger, verbose=False)
    """

    @staticmethod
    def summarize(
        model: Union[nn.Module, PreTrainedModel],
        model_name: Optional[str] = None,
        logger: Optional[logging.Logger] = None,
        verbose: bool = False,
        print_architecture: bool = False
    ) -> dict:
        """
        Compute parameter statistics for ``model`` and optionally report them.

        Args:
            model: The module to summarize.
            model_name: Display name; defaults to the model's class name.
            logger: If given, report lines are emitted with ``logger.info``.
            verbose: If True and no logger is given, report lines are printed.
            print_architecture: If True, also emit ``str(model)``. It goes through
                ``logger`` when one is given; otherwise it is printed regardless
                of ``verbose``.

        Returns:
            dict with keys "Model", "Total parameters", "Trainable parameters",
            "Frozen parameters", "Number of layers",
            "Estimated param size (MB, FP32)", "Estimated param size (MB, FP16)",
            "Estimated param size (MB, BF16)", "Trainable %" and "Frozen %".
        """
        model_name = model_name or model.__class__.__name__

        params = list(model.parameters())
        total_params = sum(p.numel() for p in params)
        trainable_params = sum(p.numel() for p in params if p.requires_grad)
        frozen_params = total_params - trainable_params
        # Counts every module yielded by model.modules(), which includes the root
        # module itself and container modules (e.g. nn.Sequential), not only leaves.
        num_layers = sum(1 for _ in model.modules())

        if total_params > 0:
            trainable_pct = 100.0 * trainable_params / total_params
            frozen_pct = 100.0 * frozen_params / total_params
        else:
            trainable_pct = frozen_pct = 0.0

        # Sizes depend only on the element count and each precision's width
        # (4 bytes for FP32, 2 for FP16/BF16), NOT on the parameters' current
        # dtype; using p.element_size() would mislabel an FP16-loaded model's
        # size as its FP32 size. "MB" here means 1024 ** 2 bytes.
        bytes_per_mb = 1024 ** 2
        model_size_fp32_mb = total_params * 4 / bytes_per_mb
        model_size_fp16_mb = total_params * 2 / bytes_per_mb
        model_size_bf16_mb = total_params * 2 / bytes_per_mb

        summary = {
            "Model": model_name,
            "Total parameters": total_params,
            "Trainable parameters": trainable_params,
            "Frozen parameters": frozen_params,
            "Number of layers": num_layers,
            "Estimated param size (MB, FP32)": model_size_fp32_mb,
            "Estimated param size (MB, FP16)": model_size_fp16_mb,
            "Estimated param size (MB, BF16)": model_size_bf16_mb,
            "Trainable %": trainable_pct,
            "Frozen %": frozen_pct,
        }

        if logger or verbose:
            ModelSummary._log_summary(summary, logger=logger, verbose=verbose)

        if print_architecture:
            # An explicit request: always emit it (via the logger if given,
            # otherwise printed), independent of the verbose flag.
            ModelSummary._output(f"\n🧱 Model Architecture for {model_name}:\n{model}", logger, True)

        return summary

    @staticmethod
    def _log_summary(summary: dict, logger: Optional[logging.Logger], verbose: bool):
        """Emit the human-readable report for a dict produced by ``summarize``."""
        ModelSummary._output("📊 Model Summary:", logger, verbose)
        ModelSummary._output(f"  Model: {summary['Model']}", logger, verbose)
        ModelSummary._output(f"  Total parameters: {summary['Total parameters']:,}", logger, verbose)
        ModelSummary._output(f"  Trainable parameters: {summary['Trainable parameters']:,} ({summary['Trainable %']:.2f}%)", logger, verbose)
        ModelSummary._output(f"  Frozen parameters: {summary['Frozen parameters']:,} ({summary['Frozen %']:.2f}%)", logger, verbose)
        ModelSummary._output(f"  Number of layers: {summary['Number of layers']:,}", logger, verbose)

        ModelSummary._output("💾 Estimated Model Parameter Size:", logger, verbose)
        ModelSummary._output(f"  FP32: {summary['Estimated param size (MB, FP32)']:,.2f} MB", logger, verbose)
        ModelSummary._output(f"  FP16: {summary['Estimated param size (MB, FP16)']:,.2f} MB", logger, verbose)
        ModelSummary._output(f"  BF16: {summary['Estimated param size (MB, BF16)']:,.2f} MB", logger, verbose)

    @staticmethod
    def _output(msg: str, logger: Optional[logging.Logger], verbose: bool):
        """Send ``msg`` to ``logger.info`` if a logger is given, else print it when ``verbose``."""
        if logger:
            logger.info(msg)
        elif verbose:
            print(msg)