#!/usr/bin/env python3
"""
Measure the parameter count and FLOPs of a Mask2Former model.

Usage examples
--------------
# ADE20K, R50:
python mask2former_flops_params.py \
  --config projects/Mask2Former/configs/ade20k/semantic-segmentation/maskformer2_R50_bs16_160k.yaml \
  --height 512 --width 512 --batch-size 1 --device cuda
"""
import os
import argparse
import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table

# --- Detectron2 / Mask2Former imports ---
from detectron2.config import get_cfg
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer
from detectron2 import model_zoo

# Mask2Former project adds registries / meta-arch
from Mask2Former.mask2former import add_maskformer2_config


# ---------------------------
# Custom FLOP handlers
# ---------------------------
def register_attention_handles(fca: "FlopCountAnalysis", model):
    """Register custom FLOP handles on ``fca`` for attention modules in ``model``.

    Two kinds of modules are looked up by walking ``model.modules()``:

    * ``torch.nn.MultiheadAttention`` instances, which get a handle that
      returns ``None``;
    * modules whose class name is one of the known multi-scale deformable
      attention names, which get a handle estimating the cost of the
      sampling/aggregation core.

    Args:
        fca: Analysis object exposing ``set_op_handle(key, handle)``.
        model: Module tree to scan (anything exposing ``modules()``).

    NOTE(review): fvcore documents ``set_op_handle`` as taking operator-name
    strings (e.g. ``"aten::addmm"``) and handles of the form
    ``handle(inputs, outputs)``. Here the key is a module *class* and the
    handles take ``(module, inputs, outputs)``, so these registrations may
    never be triggered during tracing -- confirm against the fvcore version
    in use before trusting the attention numbers.
    """
    # Class names under which deformable attention is known to be published.
    MSDA_CLASS_NAMES = {
        "MSDeformAttn",
        "MultiScaleDeformableAttention",
        "MultiScaleDeformableAttentionModule",
    }

    def _mha_flops(module, inputs, outputs):
        # Contributes no estimate of its own.
        return None

    def _ms_deform_attn_handle(module, inputs, outputs):
        """Estimate FLOPs of the deformable-attention sampling core.

        Hyper-parameters are read from the module when it exposes them and
        otherwise recovered from the input shapes.
        """
        H = getattr(module, "num_heads", None)
        C = getattr(module, "embed_dim", None) or getattr(module, "d_model", None)
        n_levels = getattr(module, "num_levels", None)
        n_points = getattr(module, "num_points", None)

        # assumes inputs follow (value, spatial_shapes, level_start_index,
        # sampling_locations, attention_weights, ...) -- TODO confirm
        x_value = inputs[0]
        sampling_locations = inputs[3]

        B = x_value.shape[0]
        Lq = sampling_locations.shape[1]

        if H is None:
            H = sampling_locations.shape[2]
        if n_levels is None:
            n_levels = sampling_locations.shape[3]
        if n_points is None:
            n_points = sampling_locations.shape[4]
        if C is None:
            C = x_value.shape[-1]

        # One multiply and one add per sampled value per channel of each head.
        # The linear projections are deliberately not added here.
        return 2.0 * B * H * Lq * n_levels * n_points * (C // H)

    # ``nn`` was never imported in this file; go through the ``torch`` module
    # that is imported so this check no longer raises NameError.
    for m in model.modules():
        if isinstance(m, torch.nn.MultiheadAttention):
            fca.set_op_handle(type(m), _mha_flops)

    # Matched by class name so no import of the deformable-attention
    # implementation is required.
    for m in model.modules():
        if type(m).__name__ in MSDA_CLASS_NAMES:
            fca.set_op_handle(type(m), _ms_deform_attn_handle)


def build_from_config(cfg_path, weights=None):
    """Build an eval-mode model on CPU from a Mask2Former config.

    Args:
        cfg_path: Config location. Paths starting with ``detectron2://`` or
            ``COCO-InstanceSegmentation`` are resolved through the Detectron2
            model zoo; anything else is merged as a plain file path.
        weights: Optional checkpoint path; when truthy it overrides
            ``cfg.MODEL.WEIGHTS``.

    Returns:
        A ``(cfg, model)`` tuple.
    """
    zoo_prefixes = ("detectron2://", "COCO-InstanceSegmentation")

    cfg = get_cfg()
    add_maskformer2_config(cfg)

    resolved_path = (
        model_zoo.get_config_file(cfg_path)
        if cfg_path.startswith(zoo_prefixes)
        else cfg_path
    )
    cfg.merge_from_file(resolved_path)

    if weights:
        cfg.MODEL.WEIGHTS = weights

    # The model is always constructed on CPU.
    cfg.MODEL.DEVICE = "cpu"
    model = build_model(cfg)
    model.eval()

    # A checkpoint is loaded only when it exists on the local filesystem;
    # otherwise the model keeps its freshly built weights.
    checkpoint = cfg.MODEL.WEIGHTS
    if checkpoint and os.path.exists(checkpoint):
        DetectionCheckpointer(model).load(checkpoint)

    return cfg, model


# Dummy input for forward
def make_batched_inputs(batch_size, height, width, device):
    """Create ``batch_size`` random per-image input dicts.

    Each dict carries a ``(3, height, width)`` tensor drawn from
    ``torch.randn`` on ``device`` under ``"image"``, together with the
    requested ``"height"`` and ``"width"``.

    Returns:
        A list of ``batch_size`` dicts (empty when ``batch_size`` is 0).
    """

    def _one_sample():
        return {
            "image": torch.randn(3, height, width, device=device),
            "height": height,
            "width": width,
        }

    return [_one_sample() for _ in range(batch_size)]


def main():
    """Parse CLI arguments, then print the FLOP table, parameter count and total FLOPs."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="Path to Mask2Former config (.yaml)")
    parser.add_argument("--weights", default="", help="Optional checkpoint path")
    parser.add_argument("--height", type=int, default=512)
    parser.add_argument("--width", type=int, default=512)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--device", choices=["cpu", "cuda"], default="cpu")
    args = parser.parse_args()

    cfg, model = build_from_config(args.config, args.weights)

    # Use CPU whenever CUDA is not available, whatever was requested.
    device_name = args.device if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    model.to(device)

    # --- Parameters (reported in millions) ---
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()
    params_in_millions = total_params / 1e6

    # --- FLOPs ---
    with torch.no_grad():
        dummy_batch = make_batched_inputs(args.batch_size, args.height, args.width, device)
        analysis = FlopCountAnalysis(model, (dummy_batch,))
        register_attention_handles(analysis, model)
        total_flops = analysis.total()
        print(flop_count_table(analysis, max_depth=4))

    giga_flops = total_flops / 1e9
    input_desc = f"{args.batch_size}x3x{args.height}x{args.width}"
    print(f"\nParams: {params_in_millions:.3f} M")
    print(f"FLOPs:  {giga_flops:.3f} GFLOPs  (for input {input_desc})")


# Run the measurement only when executed as a script, not when imported.
if __name__ == "__main__":
    main()