# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy
import pytest
from onnx import TensorProto, load_model, numpy_helper
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info

from sparseml.onnx.utils import (
    NodeParam,
    SparsityMeasurement,
    calculate_flops,
    check_load_model,
    conv_node_params,
    extract_node_id,
    extract_node_shapes,
    extract_shape,
    gemm_node_params,
    get_init_by_name,
    get_kernel_shape,
    get_node_attributes,
    get_node_by_id,
    get_node_input_nodes,
    get_node_inputs,
    get_node_output_nodes,
    get_node_outputs,
    get_node_params,
    get_nodes_by_input_id,
    get_nodes_by_output_id,
    get_numpy_dtype,
    get_prunable_node_from_foldable,
    get_prunable_nodes,
    is_foldable_node,
    is_prunable_node,
    matmul_node_params,
    model_inputs,
    model_outputs,
    onnx_nodes_sparsities,
)
from sparsezoo import Zoo


from tests.sparseml.onnx.helpers import (  # noqa isort: skip
    extract_node_models,
    onnx_repo_models,
)


@pytest.fixture
def simple_onnx_model():
    """
    Build a small two-node ONNX model for the tests in this module.

    The graph ``test-model`` takes input ``X``, runs a ``Conv`` node named
    ``node1`` (with randomly valued weight and bias initializers) that emits
    ``Y``, and then a ``ReLU`` node named ``node2`` that emits output ``Z``.

    :return: the model created from the graph
    """
    graph_inputs = [make_tensor_value_info("X", TensorProto.FLOAT, [3, 2])]
    graph_outputs = [make_tensor_value_info("Z", TensorProto.FLOAT, [3, 4])]

    conv = make_node("Conv", ["X", "node1.weight", "node1.bias"], ["Y"], name="node1")
    relu = make_node("ReLU", ["Y"], ["Z"], name="node2")

    # weight is drawn before bias, same as the order of the initializer list
    weight = numpy_helper.from_array(
        numpy.random.randn(2, 3, 3, 3), name="node1.weight"
    )
    bias = numpy_helper.from_array(numpy.random.randn(3), name="node1.bias")

    graph = make_graph(
        [conv, relu],
        "test-model",
        graph_inputs,
        graph_outputs,
        initializer=[weight, bias],
    )
    return make_model(graph)


@pytest.fixture
def foldable_onnx_model():
    """
    Build a two-node ONNX model whose second node has the op type
    ``batchnormalization``.

    The graph ``test-model`` takes input ``X``, runs a ``Conv`` node named
    ``node1`` (with randomly valued weight and bias initializers) that emits
    ``Y``, and then the ``batchnormalization`` node named ``node2`` that emits
    output ``Z``.

    :return: the model created from the graph
    """
    graph_inputs = [make_tensor_value_info("X", TensorProto.FLOAT, [3, 2])]
    graph_outputs = [make_tensor_value_info("Z", TensorProto.FLOAT, [3, 4])]

    conv = make_node("Conv", ["X", "node1.weight", "node1.bias"], ["Y"], name="node1")
    batch_norm = make_node("batchnormalization", ["Y"], ["Z"], name="node2")

    # weight is drawn before bias, same as the order of the initializer list
    weight = numpy_helper.from_array(
        numpy.random.randn(2, 3, 3, 3), name="node1.weight"
    )
    bias = numpy_helper.from_array(numpy.random.randn(3), name="node1.bias")

    graph = make_graph(
        [conv, batch_norm],
        "test-model",
        graph_inputs,
        graph_outputs,
        initializer=[weight, bias],
    )
    return make_model(graph)


def get_prunable_onnx_model():
    """
    Build a four-node ONNX model: ``Conv`` -> ``Gemm`` -> ``MatMul`` -> ``ReLU``.

    The nodes are named ``node1`` through ``node4`` and chain the tensors
    ``A`` -> ``B`` -> ``C`` -> ``D`` -> ``Z``. The first three nodes each
    reference a randomly valued weight initializer; only the conv node also
    references a bias initializer.

    :return: the model created from the graph ``test-model``
    """
    graph_inputs = [make_tensor_value_info("A", TensorProto.FLOAT, [3, 2])]
    graph_outputs = [make_tensor_value_info("Z", TensorProto.FLOAT, [4])]

    nodes = [
        make_node("Conv", ["A", "node1.weight", "node1.bias"], ["B"], name="node1"),
        make_node("Gemm", ["B", "node2.weight"], ["C"], name="node2"),
        make_node("MatMul", ["C", "node3.weight"], ["D"], name="node3"),
        make_node("ReLU", ["D"], ["Z"], name="node4"),
    ]

    # random draws happen in the same order as the initializer list below
    conv_weight = numpy_helper.from_array(
        numpy.random.randn(2, 3, 3, 3), name="node1.weight"
    )
    conv_bias = numpy_helper.from_array(numpy.random.randn(2), name="node1.bias")
    gemm_weight = numpy_helper.from_array(
        numpy.random.randn(12, 3), name="node2.weight"
    )
    matmul_weight = numpy_helper.from_array(
        numpy.random.randn(3, 4), name="node3.weight"
    )

    graph = make_graph(
        nodes,
        "test-model",
        graph_inputs,
        graph_outputs,
        initializer=[conv_weight, conv_bias, gemm_weight, matmul_weight],
    )
    return make_model(graph)


@pytest.fixture
def prunable_onnx_model():
    """
    Fixture wrapper that returns a freshly built model from
    ``get_prunable_onnx_model``.
    """
    model = get_prunable_onnx_model()
    return model


def test_check_load_model(onnx_repo_models):  # noqa: F811
    """
    ``check_load_model`` must return a model equal to the one loaded by
    ``load_model`` whether it is given the file path or the loaded model.
    """
    path = onnx_repo_models.model_path
    expected = load_model(path)

    from_path = check_load_model(path)
    assert expected == from_path

    from_model = check_load_model(expected)
    assert expected == from_model


@pytest.mark.parametrize(
    "op_type,inputs,outputs",
    [
        ("Conv", ["X"], ["Y"]),
        ("Gemm", ["X"], ["Y", "Z"]),
    ],
)
def test_extract_node_id(op_type, inputs, outputs):
    """The id extracted for a node must be the node's first output name."""
    built_node = make_node(op_type, inputs, outputs)
    node_id = extract_node_id(built_node)
    assert node_id == outputs[0]


def test_get_node_by_id(simple_onnx_model):
    """
    Every node in the graph must be retrievable through its extracted id,
    and an unknown id must yield ``None``.
    """
    model = simple_onnx_model

    for expected in model.graph.node:
        found = get_node_by_id(model, extract_node_id(expected))
        assert expected == found

    missing = get_node_by_id(model, "NONE")
    assert missing is None


def test_get_node_by_input_id(simple_onnx_model):
    """
    Looking up nodes by the input id ``Y`` must return only the final node;
    an unknown id must return an empty list.
    """
    model = simple_onnx_model
    consumer = model.graph.node[-1]

    consumers_of_y = get_nodes_by_input_id(model, "Y")
    assert consumers_of_y == [consumer]

    consumers_of_unknown = get_nodes_by_input_id(model, "NONE")
    assert consumers_of_unknown == []


def test_get_node_by_output_id(simple_onnx_model):
    """
    Looking up nodes by the output id ``Y`` must return only the first node;
    an unknown id must return an empty list.
    """
    model = simple_onnx_model
    producer = model.graph.node[0]

    producers_of_y = get_nodes_by_output_id(model, "Y")
    assert producers_of_y == [producer]

    producers_of_unknown = get_nodes_by_output_id(model, "NONE")
    assert producers_of_unknown == []


def test_extract_shape():
    """
    A value info built with shape ``[3, 2]`` must yield ``(3, 2)``; one built
    with no shape must yield ``None``.
    """
    shaped = make_tensor_value_info("X", TensorProto.FLOAT, [3, 2])
    shaped_result = extract_shape(shaped)
    assert shaped_result == (3, 2)

    shapeless = make_tensor_value_info("X", TensorProto.STRING, None)
    shapeless_result = extract_shape(shapeless)
    assert shapeless_result is None


def test_get_numpy_dtype():
    """
    FLOAT and INT32 value infos must map to ``numpy.float32`` and
    ``numpy.int32``; the STRING value info must map to ``None``.
    """
    float_info = make_tensor_value_info("X", TensorProto.FLOAT, [3, 2])
    float_dtype = get_numpy_dtype(float_info)
    assert float_dtype == numpy.float32

    int_info = make_tensor_value_info("X", TensorProto.INT32, [3, 2])
    int_dtype = get_numpy_dtype(int_info)
    assert int_dtype == numpy.int32

    string_info = make_tensor_value_info("X", TensorProto.STRING, None)
    string_dtype = get_numpy_dtype(string_info)
    assert string_dtype is None


def test_get_attributes():
    """
    Attributes passed to ``make_node`` as keyword arguments must come back
    unchanged from ``get_node_attributes``.
    """
    expected = {"kernel": [3, 3], "padding": [1, 1, 1, 1]}
    conv = make_node("Conv", ["X"], ["Y"], **expected)
    extracted = get_node_attributes(conv)
    assert extracted == expected


def test_get_inputs(simple_onnx_model):
    """
    ``get_node_inputs`` must report ``["X"]`` for the first node and ``["Y"]``
    for the second node of the simple model.
    """
    expected_per_node = [["X"], ["Y"]]
    pairs = zip(simple_onnx_model.graph.node, expected_per_node)
    for graph_node, expected in pairs:
        actual = get_node_inputs(simple_onnx_model, graph_node)
        assert actual == expected


def test_get_outputs(simple_onnx_model):
    """
    ``get_node_outputs`` must equal each node's own ``output`` field for every
    node in the simple model.
    """
    model = simple_onnx_model
    for graph_node in model.graph.node:
        actual = get_node_outputs(model, graph_node)
        assert actual == graph_node.output


@pytest.mark.parametrize(
    "node,foldable",
    [
        (make_node("batchnormalization", ["X"], ["Y"]), True),
        (make_node("add", ["X"], ["Y"]), True),
        (make_node("mul", ["X"], ["Y"]), True),
        (make_node("Other", ["X"], ["Y"]), False),
        # a plain op-type string is passed in place of a node here
        ("batchnormalization", True),
    ],
)
def test_is_foldable(node, foldable):
    """``is_foldable_node`` must match the expected flag for each case."""
    result = is_foldable_node(node)
    assert result == foldable


def test_get_prunable_node_from_foldable(foldable_onnx_model):
    """
    Exercise ``get_prunable_node_from_foldable`` on the foldable model:

    - starting from the last node it must return the first (Conv) node
    - starting from the first (Conv) node it must raise ``ValueError``
    - with ``traverse_previous=False`` it must return ``None``
    - with ``max_node_distance=0`` it must return ``None``
    """
    model = foldable_onnx_model
    first_node = model.graph.node[0]
    last_node = model.graph.node[-1]

    found = get_prunable_node_from_foldable(model, last_node)
    assert found == first_node

    with pytest.raises(ValueError):
        get_prunable_node_from_foldable(model, first_node)

    without_traversal = get_prunable_node_from_foldable(
        model, last_node, traverse_previous=False
    )
    assert without_traversal is None

    zero_distance = get_prunable_node_from_foldable(
        model, last_node, max_node_distance=0
    )
    assert zero_distance is None


def test_get_init_by_name(onnx_repo_models):  # noqa: F811
    """
    Every initializer in the loaded model must be returned by
    ``get_init_by_name`` when looked up by its own name.
    """
    loaded = load_model(onnx_repo_models.model_path)
    for initializer in loaded.graph.initializer:
        looked_up = get_init_by_name(loaded, initializer.name)
        assert initializer == looked_up


def test_is_prunable(simple_onnx_model):
    """
    The first node of the simple model must be reported as prunable and the
    last node must not.
    """
    model = simple_onnx_model
    graph_nodes = model.graph.node

    first_is_prunable = is_prunable_node(model, graph_nodes[0])
    assert first_is_prunable

    last_is_prunable = is_prunable_node(model, graph_nodes[-1])
    assert not last_is_prunable


def test_model_inputs(simple_onnx_model):
    """``model_inputs`` must equal the list of the graph's ``input`` entries."""
    expected = list(simple_onnx_model.graph.input)
    actual = model_inputs(simple_onnx_model)
    assert actual == expected


def test_model_outputs(simple_onnx_model):
    """``model_outputs`` must equal the list of the graph's ``output`` entries."""
    expected = list(simple_onnx_model.graph.output)
    actual = model_outputs(simple_onnx_model)
    assert actual == expected


def test_get_prunable_nodes(prunable_onnx_model):
    """
    ``get_prunable_nodes`` must return every node of the prunable model except
    the last one.
    """
    all_but_last = prunable_onnx_model.graph.node[:-1]
    prunable = get_prunable_nodes(prunable_onnx_model)
    assert prunable == all_but_last


def test_get_node_input_nodes(simple_onnx_model):
    """
    The last node's input nodes must be exactly the first node, and the first
    node must have no input nodes.
    """
    model = simple_onnx_model
    first_node = model.graph.node[0]
    last_node = model.graph.node[-1]

    feeding_last = get_node_input_nodes(model, last_node)
    assert feeding_last == [first_node]

    feeding_first = get_node_input_nodes(model, first_node)
    assert feeding_first == []


def test_get_node_output_nodes(simple_onnx_model):
    """
    The first node's output nodes must be exactly the last node, and the last
    node must have no output nodes.
    """
    model = simple_onnx_model
    first_node = model.graph.node[0]
    last_node = model.graph.node[-1]

    fed_by_first = get_node_output_nodes(model, first_node)
    assert fed_by_first == [last_node]

    fed_by_last = get_node_output_nodes(model, last_node)
    assert fed_by_last == []


def test_conv_node_params(prunable_onnx_model):
    """
    ``conv_node_params`` must name the conv weight and bias initializers, and
    when values are included they must have shapes ``(2, 3, 3, 3)`` and
    ``(2,)``.
    """
    model = prunable_onnx_model
    conv_nodes = [n for n in model.graph.node if n.op_type == "Conv"]
    conv_node = conv_nodes[0]

    names_only = conv_node_params(model, conv_node, include_values=False)
    expected_names = (
        NodeParam("node1.weight", None),
        NodeParam("node1.bias", None),
    )
    assert names_only == expected_names

    weight_param, bias_param = conv_node_params(model, conv_node)[:2]
    assert weight_param[1].shape == (2, 3, 3, 3)
    assert bias_param[1].shape == (2,)


def test_gemm_node_params(prunable_onnx_model):
    """
    ``gemm_node_params`` must name the gemm weight initializer with ``None``
    in the bias slot, and the included weight value must have shape
    ``(12, 3)``.
    """
    model = prunable_onnx_model
    gemm_nodes = [n for n in model.graph.node if n.op_type == "Gemm"]
    gemm_node = gemm_nodes[0]

    names_only = gemm_node_params(model, gemm_node, include_values=False)
    expected_names = (NodeParam("node2.weight", None), None)
    assert names_only == expected_names

    with_values = gemm_node_params(model, gemm_node)
    weight_param = with_values[0]
    assert weight_param[1].shape == (12, 3)


def test_matmul_node_params(prunable_onnx_model):
    """
    ``matmul_node_params`` must name the matmul weight initializer with
    ``None`` in the bias slot, and the included weight value must have shape
    ``(3, 4)``.
    """
    model = prunable_onnx_model
    matmul_nodes = [n for n in model.graph.node if n.op_type == "MatMul"]
    matmul_node = matmul_nodes[0]

    names_only = matmul_node_params(model, matmul_node, include_values=False)
    expected_names = (NodeParam("node3.weight", None), None)
    assert names_only == expected_names

    with_values = matmul_node_params(model, matmul_node)
    weight_param = with_values[0]
    assert weight_param[1].shape == (3, 4)


def test_get_node_params(prunable_onnx_model):
    """
    ``get_node_params`` must raise ``ValueError`` for the last node of the
    prunable model and return the expected parameter names for each of the
    preceding nodes.
    """
    model = prunable_onnx_model
    graph_nodes = model.graph.node

    with pytest.raises(ValueError):
        get_node_params(model, graph_nodes[-1])

    expected_by_position = [
        (NodeParam("node1.weight", None), NodeParam("node1.bias", None)),
        (NodeParam("node2.weight", None), None),
        (NodeParam("node3.weight", None), None),
    ]
    for graph_node, expected in zip(graph_nodes[:-1], expected_by_position):
        actual = get_node_params(model, graph_node, include_values=False)
        assert actual == expected


def test_onnx_node_sparsities():
    """
    Run ``onnx_nodes_sparsities`` on the ONNX file of every model returned by
    the Zoo search below and check the totals and per-node measurements.

    Runs through nearly all other onnx functions imported above as well.
    """
    search_criteria = dict(
        domain="cv",
        sub_domain="classification",
        architecture="mobilenet_v1",
        dataset="imagenet",
        framework="pytorch",
        sparse_name="pruned",
        sparse_category="moderate",
        repo="sparseml",
    )
    found_models = Zoo.search_models(**search_criteria)
    assert len(found_models) > 0

    # node names containing any of these are exempt from the sparsity bounds
    exempt_markers = (
        "depth",
        "sections.0",
        "sections_0",
        "sections.1",
        "sections_1",
        "output",
    )

    for zoo_model in found_models:
        onnx_path = zoo_model.onnx_file.downloaded_path()
        total, per_node = onnx_nodes_sparsities(onnx_path)

        assert len(per_node) == 28

        assert isinstance(total, SparsityMeasurement)
        assert total.sparsity > 0.5
        assert total.params_count == 4209088
        assert total.params_zero_count > 0.5 * total.params_count

        for node_name, measurement in per_node.items():
            assert isinstance(measurement, SparsityMeasurement)
            assert measurement.params_count > 0

            # bounds only apply to names mentioning sections or classifier
            if not ("sections" in node_name or "classifier" in node_name):
                continue
            if any(marker in node_name for marker in exempt_markers):
                continue

            assert measurement.sparsity > 0.2
            assert measurement.sparsity < 0.95
            assert measurement.params_zero_count > 0


def test_extract_node_shape(extract_node_models):  # noqa: F811
    """
    Check the input and output shapes reported by ``extract_node_shapes``
    against the expected shapes supplied by the ``extract_node_models``
    fixture.

    The fixture yields the model path followed by one or more expected-output
    mappings (node id -> (input shapes, output shapes)); empty entries are
    dropped. The test passes when at least one remaining mapping matches the
    extracted input shapes and at least one matches the extracted output
    shapes.
    """
    model_path, *expected_outputs = extract_node_models
    onnx_model = load_model(model_path)
    node_shapes = extract_node_shapes(onnx_model)

    # Depending on whether a test case was written for legacy PyTorch only or
    # for both legacy and upgraded PyTorch, the three lists below will have:
    # `len` of 1 (if only the legacy PyTorch test case is present)
    # `len` of 2 (if test cases for legacy and upgraded PyTorch are present)
    expected_outputs = [x for x in expected_outputs if x]
    correct_input_shapes = [False] * len(expected_outputs)
    correct_output_shapes = [False] * len(expected_outputs)

    extracted_nodes = set(node_shapes)
    for i, expected_output in enumerate(expected_outputs):
        # Only compare when both mappings describe exactly the same nodes.
        # A subset check alone is not enough: an extracted node missing from
        # the expected mapping would raise KeyError below instead of letting
        # the other test case be tried.
        if set(expected_output) != extracted_nodes:
            continue
        correct_input_shapes[i] = all(
            node_shapes[node].input_shapes == expected_output[node][0]
            for node in expected_output
        )
        correct_output_shapes[i] = all(
            node_shapes[node].output_shapes == expected_output[node][1]
            for node in expected_output
        )

    # If we have only one test case, it must evaluate to True.
    # If we have two test cases, at least one must evaluate to True.
    # In other words, we are happy with the test passing for legacy or
    # upgraded PyTorch (worst case scenario).
    assert any(correct_input_shapes)
    assert any(correct_output_shapes)


@pytest.mark.parametrize(
    "attributes,output",
    [
        ({"kernel": [3, 3]}, [3, 3]),
        ({"kernel_shape": [5, 5]}, [5, 5]),
        # no kernel-related key present
        ({"stride": [1, 1]}, None),
    ],
)
def test_get_kernel_shape(attributes, output):
    """``get_kernel_shape`` must return the expected value for each case."""
    kernel_shape = get_kernel_shape(attributes)
    assert kernel_shape == output


@pytest.mark.parametrize(
    "op_type,input_shape,output_shape,weight_shape,kernel_shape,bias_shape,flops",
    [
        # ops given two [1, 3, 15, 15] inputs: expected value is 3 * 15 * 15
        (
            "Add",
            [[1, 3, 15, 15], [1, 3, 15, 15]],
            [[1, 3, 15, 15]],
            None,
            None,
            None,
            3 * 15 * 15,
        ),
        (
            "Mul",
            [[1, 3, 15, 15], [1, 3, 15, 15]],
            [[1, 3, 15, 15]],
            None,
            None,
            None,
            3 * 15 * 15,
        ),
        (
            "Div",
            [[1, 3, 15, 15], [1, 3, 15, 15]],
            [[1, 3, 15, 15]],
            None,
            None,
            None,
            3 * 15 * 15,
        ),
        (
            "Sub",
            [[1, 3, 15, 15], [1, 3, 15, 15]],
            [[1, 3, 15, 15]],
            None,
            None,
            None,
            3 * 15 * 15,
        ),
        (
            "Clip",
            [[1, 3, 15, 15], [1, 3, 15, 15]],
            [[1, 3, 15, 15]],
            None,
            None,
            None,
            3 * 15 * 15,
        ),
        # single-input ops with matching input/output shape: 3 * 15 * 15
        (
            "Relu",
            [[1, 3, 15, 15]],
            [[1, 3, 15, 15]],
            None,
            None,
            None,
            3 * 15 * 15,
        ),
        (
            "LeakyRelu",
            [[1, 3, 15, 15]],
            [[1, 3, 15, 15]],
            None,
            None,
            None,
            3 * 15 * 15,
        ),
        (
            "Sigmoid",
            [[1, 3, 15, 15]],
            [[1, 3, 15, 15]],
            None,
            None,
            None,
            3 * 15 * 15,
        ),
        (
            "Tanh",
            [[1, 3, 15, 15]],
            [[1, 3, 15, 15]],
            None,
            None,
            None,
            3 * 15 * 15,
        ),
        (
            "BatchNormalization",
            [[1, 3, 15, 15]],
            [[1, 3, 15, 15]],
            None,
            None,
            None,
            3 * 15 * 15,
        ),
        # global pooling down to [1, 3, 1, 1]: expected value is 3 * 15 * 15
        (
            "GlobalAveragePool",
            [[1, 3, 15, 15]],
            [[1, 3, 1, 1]],
            None,
            None,
            None,
            3 * 15 * 15,
        ),
        (
            "GlobalMaxPool",
            [[1, 3, 15, 15]],
            [[1, 3, 1, 1]],
            None,
            None,
            None,
            3 * 15 * 15,
        ),
        # pooling with a [3, 3] kernel: expected value is 3 * 3 * (3 * 15 * 15)
        (
            "MaxPool",
            [[1, 3, 15, 15]],
            [[1, 3, 15, 15]],
            None,
            [3, 3],
            None,
            3 * 3 * 3 * 15 * 15,
        ),
        (
            "AveragePool",
            [[1, 3, 15, 15]],
            [[1, 3, 15, 15]],
            None,
            [3, 3],
            None,
            3 * 3 * 3 * 15 * 15,
        ),
        # MatMul with a [16, 8] weight: 16 * 8 * 2, plus 8 when a bias is given
        ("MatMul", [[16]], [[8]], [[16, 8]], None, None, 16 * 8 * 2),
        ("MatMul", [[16]], [[8]], [[16, 8]], None, [[8]], 16 * 8 * 2 + 8),
        # MatMul of two 4-d inputs with no weight shape supplied
        (
            "MatMul",
            [[9, 5, 7, 4], [9, 5, 4, 3]],
            [[9, 5, 7, 3]],
            None,
            None,
            None,
            (7 * 3 * (2 * 4 - 1)) * 9 * 5,
        ),
        # Gemm cases mirror the weighted MatMul cases above
        ("Gemm", [[16]], [[8]], [[16, 8]], None, None, 16 * 8 * 2),
        ("Gemm", [[16]], [[8]], [[16, 8]], None, [[8]], 16 * 8 * 2 + 8),
        # Conv with a bias shape: the expected value gains 16 * 2 * 2
        (
            "Conv",
            [[16, 4, 3, 3]],
            [[16, 16, 2, 2]],
            [16, 4, 3, 3],
            [3, 3],
            [16],
            3 * 3 * 4 * 16 * 16 * 2 * 2 + 16 * 2 * 2,
        ),
        # same Conv without a bias shape
        (
            "Conv",
            [[16, 4, 3, 3]],
            [[16, 16, 2, 2]],
            [16, 4, 3, 3],
            [3, 3],
            None,
            3 * 3 * 4 * 16 * 16 * 2 * 2,
        ),
        # Conv whose leading dimensions are the string "batch"; the expected
        # value has one factor of 16 fewer than the case above
        (
            "Conv",
            [["batch", 4, 3, 3]],
            [["batch", 16, 2, 2]],
            ["batch", 4, 3, 3],
            [3, 3],
            None,
            3 * 3 * 4 * 16 * 2 * 2,
        ),
    ],
)
def test_calculate_flops(
    op_type, input_shape, output_shape, weight_shape, kernel_shape, bias_shape, flops
):
    """
    ``calculate_flops`` must return the expected ``flops`` value for each
    parametrized combination of op type and shapes.
    """
    assert flops == calculate_flops(
        op_type,
        input_shape=input_shape,
        output_shape=output_shape,
        weight_shape=weight_shape,
        kernel_shape=kernel_shape,
        bias_shape=bias_shape,
    )


@pytest.mark.parametrize(
    "op_type,input_shape,output_shape,weight_shape,kernel_shape,bias_shape",
    [
        # Add with no output shape
        ("Add", [[1, 3, 15, 15], [1, 3, 15, 15]], None, None, None, None),
        # GlobalMaxPool with no input shape
        ("GlobalMaxPool", None, [[1, 3, 1, 1]], None, None, None),
        # MaxPool with no kernel shape
        ("MaxPool", [[1, 3, 15, 15]], [[1, 3, 15, 15]], None, None, None),
        # Gemm with no weight shape
        ("Gemm", [[16]], [[8]], None, None, None),
        # MatMul whose input shapes end in (7, 4) and (5, 3)
        (
            "MatMul",
            [[9, 5, 7, 4], [9, 5, 5, 3]],
            [[9, 5, 7, 3]],
            None,
            None,
            None,
        ),
    ],
)
def test_calculate_flops_negatives(
    op_type, input_shape, output_shape, weight_shape, kernel_shape, bias_shape
):
    """
    ``calculate_flops`` must return ``None`` for each parametrized combination
    of op type and shapes.
    """
    shape_kwargs = dict(
        input_shape=input_shape,
        output_shape=output_shape,
        weight_shape=weight_shape,
        kernel_shape=kernel_shape,
        bias_shape=bias_shape,
    )
    result = calculate_flops(op_type, **shape_kwargs)
    assert result is None
