from numbers import Number
import numpy as np
import warnings

import torch
from torch.jit import TracerWarning
from fvcore.nn import FlopCountAnalysis
from fvcore.common.checkpoint import _named_modules_with_dup
from fvcore.nn.jit_analysis import _get_scoped_trace_graph
from fvcore.nn.jit_handles import get_shape

from sonic_conv import SonicConv2d, enable_speedtest, disable_speedtest
from utils import no_train


@no_train
@torch.no_grad()
def get_flops(model, input_shape, verbose=True):
    """Count the ops of one forward pass of ``model``, skipping SonicConv2d.

    A random input of shape ``(1, *input_shape)`` is traced with fvcore.
    Every traced op that has an fvcore op handle is counted, except ops
    whose innermost trace scope is a ``SonicConv2d`` module.

    Args:
        model: Module to analyse. The dummy input uses the dtype and device
            of the first tensor-valued ``weight`` attribute found while
            iterating ``model.modules()``.
        input_shape: Per-sample input shape (without the batch dimension).
        verbose: If True, print one line per counted op and the total.

    Returns:
        int: Sum of the (int-converted) counts returned by fvcore's op
        handles over all counted ops.

    Raises:
        ValueError: If no module in ``model`` has a tensor ``weight``.
    """
    # Skip modules whose ``weight`` is missing or None (e.g. non-affine norms)
    # instead of crashing on ``None.dtype`` or an opaque StopIteration.
    weight = next(
        (
            mod.weight
            for mod in model.modules()
            if isinstance(getattr(mod, "weight", None), torch.Tensor)
        ),
        None,
    )
    if weight is None:
        raise ValueError("get_flops: model has no module with a tensor `weight`")
    dtype, device = weight.dtype, weight.device

    if verbose:
        print("=" * 30 + " FLOPS " + "=" * 30)
    flops = FlopCountAnalysis(model, (torch.rand(1, *input_shape).to(dtype).to(device),))

    # fvcore identifies modules in trace scopes by alias; remember the aliases
    # of every SonicConv2d so ops traced inside them can be excluded.
    sonic_conv_names = {
        flops._aliases[mod]
        for _, mod in _named_modules_with_dup(model)
        if isinstance(mod, SonicConv2d)
    }

    total = 0
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=TracerWarning)
        graph = _get_scoped_trace_graph(flops._model, flops._inputs, flops._aliases)
    for node in graph.nodes():
        kind = node.kind()
        if kind == "prim::PythonOp":
            kind = kind + "." + node.pyname()
        # Only the innermost scope is compared against the SonicConv2d aliases.
        if node.scopeName().split("/")[-1] in sonic_conv_names:
            continue
        if kind not in flops._op_handles:
            continue

        inputs, outputs = list(node.inputs()), list(node.outputs())
        op_counts = flops._op_handles[kind](inputs, outputs)
        if isinstance(op_counts, Number):
            op_counts = int(op_counts)
            total += op_counts
            counts_text = op_counts
        else:
            op_counts = {k: int(v) for k, v in op_counts.items()}
            total += sum(op_counts.values())
            counts_text = ", ".join(f"{k} {v}" for k, v in op_counts.items())

        if verbose:
            _inputs = _describe_node_inputs(kind, inputs)
            print(f"{kind}({', '.join(_inputs)}) -> {get_shape(outputs[0])}, flops: {counts_text}")
    if verbose:
        print("Flops:", total)
    return total


def _describe_node_inputs(kind, inputs):
    """Return printable descriptions of a traced op's relevant inputs.

    Used only for the verbose report of ``get_flops``.
    """
    if kind == "aten::einsum":
        # The first input is printed (quoted) as a constant, presumably the
        # einsum equation; the operand shapes come from the inputs of the node
        # that produced the second input.
        return [f"\"{inputs[0].toIValue()}\""] + [
            str(get_shape(v)) for v in inputs[1].node().inputs()
        ]

    # ``main`` indices are always printed; ``optional`` indices are printed
    # only when their shape is known (i.e. the tensor is not None).
    if kind in ["aten::_convolution", "aten::linear"]:
        main, optional = [0, 1], [2]
    elif kind in ["aten::batch_norm", "aten::instance_norm"]:
        main, optional = [0], [1]
    elif kind in ["aten::group_norm", "aten::layer_norm"]:
        main, optional = [0], [2]
    elif kind in ["aten::upsample_nearest2d", "aten::upsample_bilinear2d", "aten::adaptive_avg_pool2d", "aten::grid_sampler"]:
        main, optional = [0], []
    else:
        main, optional = [0, 1], []
    described = [str(get_shape(inputs[i])) for i in main]
    for i in optional:
        shape = get_shape(inputs[i])
        if shape is not None:
            described.append(str(shape))
    return described


@no_train
@torch.no_grad()
def get_inference_time(
        model,
        input_shape,
        batch_sizes=(2048, 4096, 8192),
        repeats=100,
        verbose=True,
    ):
    """Measure the mean forward-pass time of ``model`` for several batch sizes.

    For each batch size the model is put in speed-test mode with
    ``enable_speedtest`` and run once as a warm-up. The forward pass is then
    timed ``repeats`` times with CUDA events on uninitialised
    (``torch.empty``) inputs. Speed-test mode is switched off again with
    ``disable_speedtest``, even if a forward pass raises.

    Args:
        model: Module to benchmark. Inputs use the dtype and device of the
            first tensor-valued ``weight`` found in ``model.modules()``.
        input_shape: Per-sample input shape (without the batch dimension).
        batch_sizes: Iterable of batch sizes to time, in order.
        repeats: Number of timed forward passes per batch size.
        verbose: If True, print the mean time for each batch size.

    Returns:
        list: Mean elapsed time in milliseconds (as reported by
        ``torch.cuda.Event.elapsed_time``), one entry per batch size.

    Raises:
        ValueError: If no module in ``model`` has a tensor ``weight``.
    """
    # Skip modules whose ``weight`` is missing or None instead of crashing.
    weight = next(
        (
            mod.weight
            for mod in model.modules()
            if isinstance(getattr(mod, "weight", None), torch.Tensor)
        ),
        None,
    )
    if weight is None:
        raise ValueError("get_inference_time: model has no module with a tensor `weight`")
    dtype, device = weight.dtype, weight.device

    if verbose:
        print("=" * 30 + " Inference Time " + "=" * 30)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    inference_times = []
    for batch_size in batch_sizes:
        enable_speedtest(model)
        try:
            # Warm-up pass, not timed.
            _ = model(torch.empty(batch_size, *input_shape, device=device, dtype=dtype))
            times = []
            for _ in range(repeats):
                inputs = torch.empty(batch_size, *input_shape, device=device, dtype=dtype)
                # Drain pending GPU work *before* the start event so the timed
                # interval covers only this forward pass, not the host-side sync.
                torch.cuda.synchronize()
                start.record()
                _ = model(inputs)
                end.record()
                # elapsed_time requires both events to have completed.
                torch.cuda.synchronize()
                times.append(start.elapsed_time(end))
        finally:
            # Never leave the model in speed-test mode, even on failure.
            disable_speedtest(model)
        inference_time = np.mean(times)
        inference_times.append(inference_time)
        if verbose:
            print(f"Batch size {batch_size}, inference time {inference_time:0.3f}ms")
    return inference_times