# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/facebookresearch/detectron2/blob/main/tools/analyze_model.py

import torch
import time
import logging
import numpy as np
from collections import Counter
import tqdm
from fvcore.nn import flop_count_table  # can also try flop_count_str

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import CfgNode, LazyConfig, get_cfg, instantiate
from detectron2.data import build_detection_test_loader, build_detection_train_loader
from detectron2.engine import default_argument_parser
from detectron2.modeling import build_model
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.utils.analysis import (
    FlopCountAnalysis,
    activation_count_operators,
    parameter_count_table,
)
from detectron2.utils.logger import setup_logger
from detectron2.export import TracingAdapter
from torch.profiler import profile, ProfilerActivity, record_function

# fmt: off
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
# fmt: on

from mask2former import add_maskformer2_config, MaskFormerSemanticDatasetMapper

logger = logging.getLogger("detectron2")

"""
Analyzes FLOP count, parameter count, model structure and operator activation count for models
For usage example, please refer to tools/README.md
"""


def setup(args):
    """Load the config named by ``args.config_file`` and initialize loggers.

    A ``.yaml`` file is loaded as a classic ``CfgNode`` after registering the
    DeepLab and Mask2Former config keys; any other file is loaded as a
    LazyConfig. In both cases ``args.opts`` are applied as overrides.

    Returns:
        A frozen ``CfgNode`` for YAML files, otherwise the LazyConfig object.
    """
    config_path = args.config_file
    overrides = args.opts

    if not config_path.endswith(".yaml"):
        cfg = LazyConfig.apply_overrides(LazyConfig.load(config_path), overrides)
    else:
        cfg = get_cfg()
        # Project-specific keys must exist before the YAML file is merged.
        for register_keys in (add_deeplab_config, add_maskformer2_config):
            register_keys(cfg)
        cfg.merge_from_file(config_path)
        # Load data without DataLoader worker processes.
        cfg.DATALOADER.NUM_WORKERS = 0
        cfg.merge_from_list(overrides)
        cfg.freeze()

    setup_logger(name="fvcore")
    setup_logger()
    return cfg


def do_flop(cfg):
    """Count FLOPs of the model over up to ``args.num_inputs`` test samples.

    Logs a per-module FLOP table for the last processed sample, the per-operator
    GFLOPs averaged over all processed samples, and the mean ± std of the
    total GFLOPs per sample.

    Relies on the module-level ``args`` parsed in ``__main__``
    (``num_inputs``, ``use_fixed_input_size``).

    Args:
        cfg: a ``CfgNode`` (YAML config) or a LazyConfig.
    """
    if isinstance(cfg, CfgNode):
        if args.use_fixed_input_size:
            # NOTE(review): the second positional argument is presumably
            # ``is_train=True``, so the training-time augmentations (e.g. a
            # fixed-size crop) are applied to produce fixed-size inputs — confirm.
            mapper = MaskFormerSemanticDatasetMapper(cfg, True)
            data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0], mapper=mapper)
        else:
            data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
        model = build_model(cfg)
        DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
    else:
        data_loader = instantiate(cfg.dataloader.test)
        model = instantiate(cfg.model)
        model.to(cfg.train.device)
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
    model.eval()

    counts = Counter()
    total_flops = []
    flops = None
    for idx, data in zip(tqdm.trange(args.num_inputs), data_loader):  # noqa
        logger.info("Image size is {}".format(data[0]["image"].shape))
        flops = FlopCountAnalysis(model, data)
        if idx > 0:
            # Warnings are identical for every sample; only show them once.
            flops.unsupported_ops_warnings(False).uncalled_modules_warnings(False)
        counts += flops.by_operator()
        total_flops.append(flops.total())

    if not total_flops:
        logger.warning("No input was processed; cannot compute FLOPs.")
        return

    # The loader may be shorter than num_inputs, so average over what was seen.
    num_samples = len(total_flops)
    logger.info("Flops table computed from only one input sample:\n" + flop_count_table(flops))
    logger.info(
        "Average GFlops for each type of operators:\n"
        + str([(k, v / num_samples / 1e9) for k, v in counts.items()])
    )
    logger.info(
        "Total GFlops: {:.1f}±{:.1f}".format(np.mean(total_flops) / 1e9, np.std(total_flops) / 1e9)
    )


def do_activation(cfg):
    """Count activations of the model over up to ``args.num_inputs`` test samples.

    Logs the per-operator activation counts averaged over the processed
    samples, and the mean ± std of the total count per sample.

    Relies on the module-level ``args`` parsed in ``__main__`` (``num_inputs``).

    Args:
        cfg: a ``CfgNode`` (YAML config) or a LazyConfig.
    """
    if isinstance(cfg, CfgNode):
        data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
        model = build_model(cfg)
        DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
    else:
        data_loader = instantiate(cfg.dataloader.test)
        model = instantiate(cfg.model)
        model.to(cfg.train.device)
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
    model.eval()

    counts = Counter()
    total_activations = []
    for _, data in zip(tqdm.trange(args.num_inputs), data_loader):  # noqa
        count = activation_count_operators(model, data)
        counts += count
        total_activations.append(sum(count.values()))

    if not total_activations:
        logger.warning("No input was processed; cannot compute activations.")
        return

    # Average over the number of samples actually processed. Dividing by the
    # last loop index would be one too few (and zero for a single input).
    num_samples = len(total_activations)
    logger.info(
        "(Million) Activations for Each Type of Operators:\n"
        + str([(k, v / num_samples) for k, v in counts.items()])
    )
    logger.info(
        "Total (Million) Activations: {}±{}".format(
            np.mean(total_activations), np.std(total_activations)
        )
    )

def forward_memory_trade_pass(model, inputs):
    """Run a single forward pass of ``model`` on a one-element batch.

    This is the region profiled by ``do_memory``.

    Args:
        model: the model to run; it is called as ``model(inputs)``.
        inputs: a batch as yielded by the data loader; must contain exactly
            one element.

    Returns:
        Whatever ``model(inputs)`` returns.

    Raises:
        AssertionError: if ``inputs`` does not contain exactly one element.
    """
    assert len(inputs) == 1, "Please use batch size=1"
    return model(inputs)

def do_memory(cfg):
    """Profile time and memory of one forward pass and log the operator table.

    Takes the first batch from the test data loader, runs it through
    ``forward_memory_trade_pass`` under ``torch.profiler`` with memory, shape,
    stack and module recording enabled, and logs the per-operator averages
    sorted by self CUDA time.

    Args:
        cfg: a ``CfgNode`` (YAML config) or a LazyConfig.
    """
    if isinstance(cfg, CfgNode):
        data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
        model = build_model(cfg)
        DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
    else:
        data_loader = instantiate(cfg.dataloader.test)
        model = instantiate(cfg.model)
        model.to(cfg.train.device)
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
    model.eval()

    # Only a single batch is profiled.
    data = next(iter(data_loader), None)
    if data is None:
        logger.warning("Data loader is empty; nothing to profile.")
        return

    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    # NOTE(review): the forward pass runs without torch.no_grad(), so autograd
    # bookkeeping is included in the profile — confirm this is intended.
    with torch.profiler.profile(
        activities=activities,
        profile_memory=True,
        record_shapes=True,
        with_stack=True,
        with_modules=True,
    ) as prof:
        forward_memory_trade_pass(model, data)

    logger.info(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))


def do_peak_memory(cfg):
    """Measure peak CUDA memory of inference over up to ``args.num_inputs`` samples.

    For each batch, the CUDA allocator's peak statistics are reset, the model
    runs under ``torch.inference_mode`` and ``torch.autocast('cuda')``, and the
    peak allocated and peak reserved bytes are recorded. The mean ± std (in GiB)
    of both are logged.

    Relies on the module-level ``args`` parsed in ``__main__``
    (``num_inputs``, ``use_fixed_input_size``, ``load_weights``).

    Args:
        cfg: a ``CfgNode`` (YAML config) or a LazyConfig.
    """
    if isinstance(cfg, CfgNode):
        if args.use_fixed_input_size:
            # NOTE(review): the second positional argument is presumably
            # ``is_train=True`` (training-time fixed-size crop) — confirm.
            mapper = MaskFormerSemanticDatasetMapper(cfg, True)
            data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0], mapper=mapper)
        else:
            data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
        model = build_model(cfg)
        if args.load_weights:
            DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
    else:
        data_loader = instantiate(cfg.dataloader.test)
        model = instantiate(cfg.model)
        model.to(cfg.train.device)
        if args.load_weights:
            DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
    model.eval()
    logger.info("Model device: {}".format(model.device))

    gib = 1024 ** 3
    total_alloc = []
    total_res = []
    for i, data in enumerate(data_loader):
        torch.cuda.reset_peak_memory_stats()
        with torch.inference_mode():
            with torch.autocast('cuda'):
                model(data)
        torch.cuda.synchronize()
        # Both values are peaks since the reset above; memory_reserved() would
        # only report the *current* reservation, not the peak.
        total_alloc.append(torch.cuda.max_memory_allocated())
        total_res.append(torch.cuda.max_memory_reserved())
        if (i + 1) == args.num_inputs:
            break

    if not total_alloc:
        logger.warning("No input was processed; cannot measure peak memory.")
        return

    logger.info("Peak Allocated Memory: {:.1f}±{:.1f}GB".format(np.mean(total_alloc) / gib, np.std(total_alloc) / gib))
    logger.info("Peak Reserved Memory: {:.1f}±{:.1f}GB".format(np.mean(total_res) / gib, np.std(total_res) / gib))


def do_parameter(cfg):
    """Log a per-module parameter count table (modules up to depth 5).

    Args:
        cfg: a ``CfgNode`` (built with ``build_model``) or a LazyConfig
            (built by instantiating ``cfg.model``).
    """
    model = build_model(cfg) if isinstance(cfg, CfgNode) else instantiate(cfg.model)
    table = parameter_count_table(model, max_depth=5)
    logger.info("Parameter Count:\n" + table)


def do_structure(cfg):
    """Log the module hierarchy of the model as printed by ``str(model)``.

    Args:
        cfg: a ``CfgNode`` (built with ``build_model``) or a LazyConfig
            (built by instantiating ``cfg.model``).
    """
    if not isinstance(cfg, CfgNode):
        model = instantiate(cfg.model)
    else:
        model = build_model(cfg)
    description = str(model)
    logger.info("Model Structure:\n" + description)

def do_fps(cfg, num_warmup=50, total_iters=250, log_interval=25):
    """Benchmark inference throughput (images per second) of the model.

    Iterates over the test data loader for up to ``total_iters`` batches. The
    first ``num_warmup`` iterations are excluded from timing, because the first
    iterations can be much slower (e.g. cuDNN autotuning is enabled below).
    The running FPS is logged every ``log_interval`` iterations, and the overall
    FPS over all timed iterations is logged at the end.

    Relies on the module-level ``args`` parsed in ``__main__``
    (``use_fixed_input_size``).

    Args:
        cfg: a ``CfgNode`` (YAML config) or a LazyConfig.
        num_warmup: number of initial iterations excluded from timing.
        total_iters: maximum number of iterations to run, warm-up included.
        log_interval: log the running FPS every this many iterations.
    """
    torch.backends.cudnn.benchmark = True
    if isinstance(cfg, CfgNode):
        if args.use_fixed_input_size:
            # NOTE(review): the second positional argument is presumably
            # ``is_train=True`` (training-time fixed-size crop) — confirm.
            mapper = MaskFormerSemanticDatasetMapper(cfg, True)
            data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0], mapper=mapper)
        else:
            data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
        model = build_model(cfg)
        DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
    else:
        data_loader = instantiate(cfg.dataloader.test)
        model = instantiate(cfg.model)
        model.to(cfg.train.device)
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
    model.eval()

    use_cuda = torch.cuda.is_available()

    def _sync():
        # CUDA kernels run asynchronously. Synchronize so the wall-clock
        # interval covers the GPU work itself, not just the kernel launches.
        if use_cuda:
            torch.cuda.synchronize()

    pure_inf_time = 0.0
    num_timed = 0
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            _sync()
            start_time = time.perf_counter()
            model(data)
            _sync()
            elapsed = time.perf_counter() - start_time

            if i >= num_warmup:
                pure_inf_time += elapsed
                num_timed += 1
                if (i + 1) % log_interval == 0:
                    fps = num_timed / pure_inf_time
                    logger.info(f'Done image [{i + 1:<3}/ {total_iters}], '
                                f'fps: {fps:.2f} img / s')

            if (i + 1) == total_iters:
                break

    # Report even when the loader had fewer than total_iters batches.
    if num_timed == 0 or pure_inf_time <= 0:
        logger.warning("No timed iterations after warm-up; cannot compute fps.")
        return
    fps = num_timed / pure_inf_time
    logger.info(f'Overall fps: {fps:.2f} img / s')


if __name__ == "__main__":
    parser = default_argument_parser(
        epilog="""
Examples:
To show parameters of a model:
$ ./analyze_model.py --tasks parameter \\
    --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml
Flops and activations are data-dependent, therefore inputs and model weights
are needed to count them:
$ ./analyze_model.py --num-inputs 100 --tasks flop \\
    --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \\
    MODEL.WEIGHTS /path/to/model.pkl
"""
    )
    parser.add_argument(
        "--tasks",
        choices=["flop", "activation", "parameter", "structure", "memory", "pmemory", "fps"],
        required=True,
        nargs="+",
    )
    parser.add_argument(
        "-n",
        "--num-inputs",
        default=100,
        type=int,
        help="number of inputs used to compute statistics for flops/activations, "
        "both are data dependent.",
    )
    parser.add_argument(
        "--use-fixed-input-size",
        action="store_true",
        help="use fixed input size when calculating flops",
    )
    parser.add_argument(
        "--load-weights",
        action="store_true",
        help="If trained weights should be loaded",
    )
    args = parser.parse_args()
    assert not args.eval_only
    assert args.num_gpus == 1

    cfg = setup(args)

    for task in args.tasks:
        {
            "flop": do_flop,
            "activation": do_activation,
            "parameter": do_parameter,
            "structure": do_structure,
            "fps": do_fps,
            "pmemory": do_peak_memory,
        }[task](cfg)
