import torch
from typing import Dict, Optional


# ============================================================
# Utilities
# ============================================================

def count_params(module: torch.nn.Module) -> int:
    """Return how many scalar parameters in *module* are trainable.

    Frozen parameters (``requires_grad == False``) are skipped. A tensor that
    is reachable through several submodules is yielded once by
    ``Module.parameters()`` and is therefore counted once.
    """
    trainable_total = 0
    for param in module.parameters():
        if not param.requires_grad:
            continue
        trainable_total += param.numel()
    return trainable_total


def format_params(n: int) -> str:
    """Render a parameter count compactly with a B / M / K suffix.

    Values of at least one thousand are scaled to the largest matching unit
    and shown with two decimals (e.g. ``1500 -> "1.50K"``). Anything smaller
    is returned via ``str`` unchanged.
    """
    for threshold, suffix in ((1e9, "B"), (1e6, "M"), (1e3, "K")):
        if n >= threshold:
            return f"{n / threshold:.2f}{suffix}"
    return str(n)


# ============================================================
# 1. Empirical inspection (real model)
# ============================================================

def _count_unique_trainable(*modules) -> int:
    """Count trainable parameters across *modules*, counting each tensor once.

    ``Module.parameters()`` already de-duplicates within a single module. This
    additionally de-duplicates across the given modules, so a weight shared by
    two of them (e.g. tied input/output embeddings) is not double counted.
    Entries that are not ``torch.nn.Module`` instances (including ``None``)
    are skipped.
    """
    seen = set()
    total = 0
    for mod in modules:
        if not isinstance(mod, torch.nn.Module):
            continue
        for param in mod.parameters():
            if param.requires_grad and id(param) not in seen:
                seen.add(id(param))
                total += param.numel()
    return total


def _unwrap_model(model):
    """Strip every wrapper layer that exposes the wrapped model as ``.module``."""
    module = model
    # Wrappers may be nested (e.g. a data-parallel wrapper around a
    # mixed-precision wrapper), so unwrap until no ``.module`` is left.
    while hasattr(module, "module") and module.module is not module:
        module = module.module
    return module


def _expert_stats(experts):
    """Return ``(total, num_experts, params_per_expert)`` for an experts container.

    ``params_per_expert`` is the integer average (0 when there are no experts).

    Raises:
        TypeError: if the container class is not a supported expert layout.
    """
    kind = type(experts).__name__
    if kind == "TEGroupedMLP":
        # All experts live in shared grouped weights; only the count is exposed.
        total = _count_unique_trainable(experts)
        n_exp = experts.num_local_experts
    elif kind == "SequentialMLP":
        # Sum every expert instead of assuming they are all the size of the first.
        n_exp = len(experts.local_experts)
        total = _count_unique_trainable(*experts.local_experts)
    else:
        raise TypeError(f"Unknown expert type: {type(experts)}")
    per_exp = total // n_exp if n_exp > 0 else 0
    return total, n_exp, per_exp


def collect_empirical_report(model) -> Dict:
    """Build a parameter report by inspecting an instantiated model.

    The model is unwrapped through any nested ``.module`` wrappers. Its
    decoder layers are read from ``module.decoder.layers``. Only trainable
    parameters (``requires_grad``) are counted. Counts reflect the modules
    held by this process; under tensor/pipeline/expert parallelism these may
    be shards or a subset of layers (confirm for your setup).

    Buckets:
      * ``embeddings_params``: ``module.embedding`` plus ``module.output_layer``
        when present, with tied weights counted once.
      * ``attention_params``: each layer's ``self_attention`` and
        ``cross_attention``.
      * ``router_params``: each layer's ``mlp.router``.
      * ``expert_params``: MoE experts. A dense MLP (one with neither
        ``router`` nor ``experts``) counts as a single expert, matching the
        analytical report.
      * ``total_params`` and per-layer ``total``: every trainable parameter of
        the model / layer. This includes norms and any modules not listed in
        a bucket (e.g. shared experts).

    Args:
        model: A ``torch.nn.Module``, possibly wrapped, exposing
            ``decoder.layers``.

    Returns:
        A report dict with the same keys as ``collect_analytical_report``.

    Raises:
        TypeError: if a layer's ``mlp.experts`` is of an unsupported type.
    """
    module = _unwrap_model(model)

    report = {
        "source": "empirical",
        "total_params": 0,
        "embeddings_params": 0,
        "attention_params": 0,
        "router_params": 0,
        "expert_params": 0,
        "experts_total": 0,
        "layers": [],
    }

    report["embeddings_params"] = _count_unique_trainable(
        getattr(module, "embedding", None),
        getattr(module, "output_layer", None),
    )

    for layer_idx, layer in enumerate(module.decoder.layers):
        layer_report = {
            "layer": layer_idx,
            "self_attention": 0,
            "cross_attention": 0,
            "router": 0,
            "experts": 0,
            "num_experts": 0,
            "params_per_expert": 0,
            # Every trainable parameter in the layer, including norms.
            "total": _count_unique_trainable(layer),
        }

        # ---- Attention ----
        for attn_name in ("self_attention", "cross_attention"):
            attn = getattr(layer, attn_name, None)
            if attn is not None:
                n = _count_unique_trainable(attn)
                layer_report[attn_name] = n
                report["attention_params"] += n

        # ---- MLP / MoE ----
        mlp = getattr(layer, "mlp", None)
        if mlp is not None:
            router = getattr(mlp, "router", None)
            if router is not None:
                n = _count_unique_trainable(router)
                layer_report["router"] = n
                report["router_params"] += n

            experts = getattr(mlp, "experts", None)
            if experts is not None:
                stats = _expert_stats(experts)
            elif not hasattr(mlp, "router") and not hasattr(mlp, "experts"):
                # Dense MLP: report it as one expert, like the analytical estimate.
                n = _count_unique_trainable(mlp)
                stats = (n, 1, n)
            else:
                stats = None

            if stats is not None:
                total, n_exp, per_exp = stats
                layer_report.update({
                    "experts": total,
                    "num_experts": n_exp,
                    "params_per_expert": per_exp,
                })
                report["expert_params"] += total
                report["experts_total"] += n_exp

        report["layers"].append(layer_report)

    # Whole-model total (layers, embeddings, output layer, final norm, ...).
    report["total_params"] = _count_unique_trainable(module)

    report["avg_params_per_expert"] = (
        report["expert_params"] // report["experts_total"]
        if report["experts_total"] > 0 else 0
    )

    return report


# ============================================================
# 2. Analytical estimation (args-based)
# ============================================================

def collect_analytical_report(args) -> Dict:
    """Estimate parameter counts analytically from training arguments.

    All arithmetic is done in integers, so large models get exact counts
    rather than float results truncated by ``int()``. Linear-layer biases are
    ignored.

    Read from ``args``: ``hidden_size``, ``num_layers``, ``kv_channels``,
    ``num_attention_heads``, ``group_query_attention``, ``num_query_groups``
    (only when GQA is on), ``num_experts`` (``None`` for a dense MLP),
    ``swiglu``, ``ffn_hidden_size``, ``padded_vocab_size``,
    ``untie_embeddings_and_output_weights``. Also reads ``normalization``
    when present ("RMSNorm" drops the norm bias; anything else, or a missing
    attribute, is treated as LayerNorm with weight and bias).

    NOTE(review): expert width is taken from ``ffn_hidden_size``. If your
    args carry a separate MoE FFN size, verify this matches your model.

    Returns:
        A report dict with the same keys as ``collect_empirical_report``. A
        dense MLP is reported as a single expert per layer.
    """
    d = args.hidden_size
    num_layers = args.num_layers
    heads = args.num_attention_heads
    kv_channels = args.kv_channels

    # ---- Attention ----
    # Without grouped-query attention every head has its own K/V projection.
    kv_groups = args.num_query_groups if args.group_query_attention else heads
    # Q and output projections are d x (kv_channels * heads) each; K and V are
    # d x (kv_channels * kv_groups) each.
    attn_per_layer = 2 * d * kv_channels * (heads + kv_groups)

    # ---- MLP / MoE ----
    is_moe = args.num_experts is not None
    num_experts = args.num_experts if is_moe else 1
    # fc1 + fc2; a gated (SwiGLU) fc1 is twice as wide, giving 3 matrices.
    params_per_expert = (3 if args.swiglu else 2) * d * args.ffn_hidden_size
    mlp_per_layer = params_per_expert * num_experts

    # ---- Router ----
    # One hidden-size weight row per expert; only MoE layers have a router.
    router_per_layer = d * num_experts if is_moe else 0

    # ---- Norms ----
    # Two norms per layer plus one final norm.
    norm_width = d if getattr(args, "normalization", "LayerNorm") == "RMSNorm" else 2 * d
    ln_per_layer = 2 * norm_width
    final_ln = norm_width

    layer_total = attn_per_layer + router_per_layer + mlp_per_layer + ln_per_layer

    # ---- Embeddings ----
    embed = d * args.padded_vocab_size
    if args.untie_embeddings_and_output_weights:
        # A separate output projection doubles the embedding-sized parameters.
        embed *= 2

    total = layer_total * num_layers + final_ln + embed

    return {
        "source": "analytical",
        "total_params": int(total),
        "embeddings_params": int(embed),
        "attention_params": int(attn_per_layer * num_layers),
        "router_params": int(router_per_layer * num_layers),
        "expert_params": int(mlp_per_layer * num_layers),
        "experts_total": num_experts * num_layers,
        "avg_params_per_expert": int(params_per_expert),
        "layers": [
            {
                "layer": i,
                "self_attention": int(attn_per_layer),
                "cross_attention": 0,
                "router": int(router_per_layer),
                "experts": int(mlp_per_layer),
                "num_experts": num_experts,
                "params_per_expert": int(params_per_expert),
                "total": int(layer_total),
            }
            for i in range(num_layers)
        ],
    }


# ============================================================
# 3. Unified printer
# ============================================================

def print_model_report(report: Dict):
    """Print a parameter report (empirical or analytical) to stdout.

    First prints a summary of the model-wide buckets. Then prints one line
    per entry in ``report["layers"]``. Parameter counts are rendered with
    ``format_params``; the source label and expert count are printed as-is.
    """
    print("\n================ MODEL PARAMETER REPORT ================\n")

    # (label, report key, renderer) for each summary line, in print order.
    summary_rows = (
        ("Source                   : ", "source", format),
        ("Total params             : ", "total_params", format_params),
        ("Embedding params         : ", "embeddings_params", format_params),
        ("Attention params         : ", "attention_params", format_params),
        ("Router params            : ", "router_params", format_params),
        ("Expert params            : ", "expert_params", format_params),
        ("Total experts            : ", "experts_total", format),
        ("Params per expert        : ", "avg_params_per_expert", format_params),
    )
    for label, key, render in summary_rows:
        print(f"{label}{render(report[key])}")

    print("\n---------------- Per-layer breakdown ----------------\n")

    for entry in report["layers"]:
        index = entry["layer"]
        sa = format_params(entry["self_attention"])
        ca = format_params(entry["cross_attention"])
        router = format_params(entry["router"])
        expert_count = entry["num_experts"]
        per_expert = format_params(entry["params_per_expert"])
        layer_total = format_params(entry["total"])
        print(
            f"Layer {index:02d} | SA={sa}, CA={ca}, router={router}, "
            f"experts={expert_count} × {per_expert}, TOTAL={layer_total}"
        )

    print("\n=======================================================\n")