import os
import sys
import torch

# ---- make the repository root importable ----
# This file sits in <repo>/tools/ while the `pointcept` package lives under
# <repo>/; put <repo> at the front of sys.path (only if it is not already
# there) so the `pointcept` imports that follow resolve against this checkout.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJ_ROOT = os.path.normpath(os.path.join(_THIS_DIR, os.pardir))
if _PROJ_ROOT not in sys.path:
    sys.path.insert(0, _PROJ_ROOT)

from pointcept.engines.defaults import (
    default_argument_parser,
    default_config_parser,
)
from pointcept.models import build_model


def count_params(model, *, trainable_only=False):
    """Return the number of scalar parameters in ``model``.

    Args:
        model: A ``torch.nn.Module``. Parameters are taken from
            ``model.parameters()``, which yields each shared ``Parameter``
            only once, so tied weights are not double-counted.
        trainable_only: If True, count only parameters with
            ``requires_grad=True`` (e.g. to exclude frozen layers).
            Defaults to False, which counts every parameter.

    Returns:
        The sum of ``numel()`` over the selected parameters (0 for a model
        without parameters).
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )


def _build_dummy_input(backbone_cfg, num_points, grid_size, device):
    """Build a random single-sample point-cloud input dict for a backbone.

    Args:
        backbone_cfg: Backbone config; its ``in_channels`` attribute (if any)
            sets the feature width, otherwise 4 is used.
        num_points: Number of random points in the sample.
        grid_size: Voxel size used to derive ``grid_coord`` from ``coord``.
        device: Device on which every tensor is allocated.

    Returns:
        Dict with ``coord`` (N, 3) float, ``grid_coord`` (N, 3) int32,
        ``feat`` (N, in_channels) float and ``offset`` (1,) int32 tensors.
        All values are tensors so the dict can go through fvcore's
        JIT-trace-based FLOP counting.
    """
    in_ch = getattr(backbone_cfg, "in_channels", 4)
    coord = torch.randn(num_points, 3, device=device)
    # Integer voxel indices: shift to a non-negative origin and divide by the
    # voxel size. Without this key, backbones that read ``grid_coord`` fail
    # before any FLOPs are counted.
    # NOTE(review): grid_size=0.02 is a generic default; match the config's
    # grid-sampling size if the exact voxel count matters.
    grid_coord = torch.div(
        coord - coord.min(dim=0).values, grid_size, rounding_mode="trunc"
    ).int()
    return {
        "coord": coord,
        "grid_coord": grid_coord,
        "feat": torch.randn(num_points, in_ch, device=device),
        # Single sample in the batch, so the cumulative offset is just N.
        "offset": torch.tensor([num_points], dtype=torch.int32, device=device),
    }


def _count_gflops(model, inputs):
    """Return fvcore's FLOP count for ``model(*inputs)`` in GFLOPs, or None.

    fvcore analyses lazily: the model is only traced when a statistic such as
    ``total()`` is requested, so that call must happen inside ``no_grad``.
    fvcore counts one fused multiply-add as one flop. Operators it has no
    counting rule for contribute 0 and are reported here, because the total
    is then a lower bound. Any failure (fvcore missing, untraceable model, ...)
    is printed and turned into ``None`` so the parameter count still prints.
    """
    try:
        from fvcore.nn import FlopCountAnalysis
    except ImportError as e:
        print("[WARN] FLOPs failed: fvcore is not installed:", e)
        return None

    try:
        with torch.no_grad():
            flops = FlopCountAnalysis(model, inputs)
            total = flops.total()  # tracing happens here, under no_grad
            unsupported = flops.unsupported_ops()
    except Exception as e:  # best-effort: tracing arbitrary backbones can fail
        print("[WARN] FLOPs failed:", e)
        return None

    if unsupported:
        print(f"[WARN] FLOPs exclude unsupported ops: {dict(unsupported)}")
    return total / 1e9


def main(num_points=100000, grid_size=0.02):
    """Print FLOPs and parameter count of the backbone in a Pointcept config.

    The config file and override options come from Pointcept's own
    ``default_argument_parser`` (``args.config_file`` / ``args.options``).
    Only ``cfg.model.backbone`` is built and profiled; the rest of
    ``cfg.model`` is ignored. The model and input are placed on CUDA.

    Args:
        num_points: Number of random points in the dummy input.
        grid_size: Voxel size used to derive the dummy ``grid_coord``.
    """
    # Use Pointcept's own parser so config/option handling matches its tools.
    args = default_argument_parser().parse_args()

    # 1. Load the config (.py file, including any _base_ inheritance).
    cfg = default_config_parser(args.config_file, args.options)

    # 2. Only the backbone is profiled.
    backbone_cfg = cfg.model.backbone

    print(f"[Config] {args.config_file}")
    print(f"[Backbone] {backbone_cfg.type}")

    # 3. Build the model on the GPU in inference mode.
    device = torch.device("cuda")
    model = build_model(backbone_cfg).to(device)
    model.eval()

    # 4. Random dummy input.
    dummy = _build_dummy_input(backbone_cfg, num_points, grid_size, device)

    # 5. Parameters.
    params = count_params(model)

    # 6. FLOPs (None if they could not be measured).
    flops_g = _count_gflops(model, (dummy,))

    print("------------------------------------------------")
    if flops_g is not None:
        print(f"FLOPs : {flops_g:.3f} G")
    else:
        print("FLOPs : N/A")
    print(f"Params: {params/1e6:.3f} M")
    print("------------------------------------------------")


if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    main()