# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import NamedTuple

import pytest
import torch

from sparseml.pytorch.utils import ModuleExporter
from sparsezoo import Zoo
from tests.sparseml.pytorch.helpers import ConvNet, LinearNet, MLPNet


__all__ = [
    "extract_node_models",
    "analyzer_models",
    "onnx_repo_models",
    "GENERATE_TEST_FILES",
]

# Cache folder (under the user's home directory) where exported test models
# are written and reused across test runs.
TEMP_FOLDER = os.path.expanduser(os.path.join("~", ".cache", "nm_models"))
# Absolute directory that contains this helper module.
RELATIVE_PATH = os.path.split(os.path.realpath(__file__))[0]

# Opt-in switch for regenerating ONNX test data. The raw environment string is
# kept as-is, so any non-empty value other than "0" is truthy (note: this
# includes strings such as "false"). An unset variable or the literal "0"
# yields False.
GENERATE_TEST_FILES = os.getenv("NM_ML_GENERATE_ONNX_TEST_DATA", False)
if GENERATE_TEST_FILES == "0":
    GENERATE_TEST_FILES = False


@pytest.fixture(
    # Each param is a pair ``[legacy_entry, upgraded_entry_or_None]``. An entry
    # is ``(model_name, model_class, sample_batch, expected_output)``; the
    # expected output maps node names to what appears to be
    # ``(input_shapes, output_shapes)`` pairs.
    params=[
        [
            (
                "test_linear_net",
                LinearNet,
                torch.randn(8),
                {
                    "output": ([[8]], [[8]]),
                    "10": ([[8]], [[16]]),
                    "11": ([[16]], [[16]]),
                    "13": ([[16]], [[32]]),
                    "14": ([[32]], [[32]]),
                    "16": ([[32]], [[16]]),
                    "17": ([[16]], [[16]]),
                    "19": ([[16]], [[8]]),
                    "input": (None, [[8]]),
                },
            ),
            None,
        ],
        [
            (
                "test_mlp_net",
                MLPNet,
                torch.randn(8),
                {
                    "output": ([[64]], [[64]]),
                    "8": ([[8]], [[16]]),
                    "9": ([[16]], [[16]]),
                    "10": ([[16]], [[16]]),
                    "12": ([[16]], [[32]]),
                    "13": ([[32]], [[32]]),
                    "14": ([[32]], [[32]]),
                    "16": ([[32]], [[64]]),
                    "17": ([[64]], [[64]]),
                    "input": (None, [[8]]),
                },
            ),
            None,
        ],
        [
            (
                "test_conv_net",
                ConvNet,
                torch.randn(16, 3, 3, 3),
                {
                    "output": ([[16, 10]], [[16, 10]]),
                    "7": ([[16, 3, 3, 3]], [[16, 16, 2, 2]]),
                    "8": ([[16, 16, 2, 2]], [[16, 16, 2, 2]]),
                    "9": ([[16, 16, 2, 2]], [[16, 32, 1, 1]]),
                    "10": ([[16, 32, 1, 1]], [[16, 32, 1, 1]]),
                    "11": ([[16, 32, 1, 1]], [[16, 32, 1, 1]]),
                    "12": ([[16, 32, 1, 1]], [[4]]),
                    "13": (None, None),
                    "14": ([[4]], None),
                    "16": (None, [[1]]),
                    "18": ([[1]], [[2]]),
                    "19": ([[16, 32, 1, 1], [2]], [[16, 32]]),
                    "20": ([[16, 32]], [[16, 10]]),
                    "input": (None, [[16, 3, 3, 3]]),
                },
            ),
            (
                "test_conv_net_upgraded_pytorch",
                ConvNet,
                torch.randn(16, 3, 3, 3),
                {
                    "output": ([[16, 10]], [[16, 10]]),
                    "7": ([[16, 3, 3, 3]], [[16, 16, 2, 2]]),
                    "8": ([[16, 16, 2, 2]], [[16, 16, 2, 2]]),
                    "9": ([[16, 16, 2, 2]], [[16, 32, 1, 1]]),
                    "10": ([[16, 32, 1, 1]], [[16, 32, 1, 1]]),
                    "11": ([[16, 32, 1, 1]], [[16, 32, 1, 1]]),
                    "17": ([[16, 32, 1, 1]], [[16, 32]]),
                    "18": ([[16, 32]], [[16, 10]]),
                    "input": (None, [[16, 3, 3, 3]]),
                },
            ),
        ],
    ]
)
def extract_node_models(request):
    """
    Export (or reuse a cached export of) a small test network as ONNX.

    The current param is a ``[legacy_entry, upgraded_entry]`` pair. The legacy
    entry supplies the model name, model class and sample batch used for the
    export, plus the expected output for an older PyTorch release. The
    upgraded entry is either ``None`` or a tuple whose last element is the
    expected output for a newer PyTorch release (1.10.2); nothing else from
    it is read.

    The ONNX file is written to ``TEMP_FOLDER/<model_name>/model.onnx`` and is
    only exported when that file does not already exist.

    :param request: pytest fixture request carrying the current param
    :return: tuple of (path to the ONNX file, legacy expected output,
        upgraded expected output or None)
    """
    legacy_entry, upgraded_entry = request.param

    model_name, model_class, sample_batch, legacy_expected = legacy_entry
    # Only the trailing expected-output element of the upgraded entry matters;
    # the exported model always comes from the legacy entry.
    upgraded_expected = upgraded_entry[-1] if upgraded_entry else None

    cache_dir = os.path.join(TEMP_FOLDER, model_name)
    os.makedirs(cache_dir, exist_ok=True)
    onnx_path = os.path.join(cache_dir, "model.onnx")

    # Exporting is skipped whenever a previously cached model is present.
    if not os.path.exists(onnx_path):
        ModuleExporter(model_class(), cache_dir).export_onnx(
            sample_batch=sample_batch
        )

    return os.path.expanduser(onnx_path), legacy_expected, upgraded_expected


# TODO update when flops are done
# add a list for exact output
# [python10_output, python9_output]
# check whether each of those passes; if at least one passes, then good to go.
@pytest.fixture(
    params=[
        [
            (
                "test_linear_net",
                LinearNet,
                torch.randn(8),
                {
                    "nodes": [
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 256.0,
                            "id": "10",
                            "input_names": ["input"],
                            "input_shapes": [[8]],
                            "op_type": "MatMul",
                            "output_names": ["10"],
                            "output_shapes": [[16]],
                            "params": 128,
                            "prunable": True,
                            "prunable_equation_sensitivity": 0.375,
                            "prunable_params": 128,
                            "prunable_params_zeroed": 0,
                            "weight_name": "21",
                            "weight_shape": [8, 16],
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 16.0,
                            "id": "11",
                            "input_names": ["10"],
                            "input_shapes": [[16]],
                            "op_type": "Add",
                            "output_names": ["11"],
                            "output_shapes": [[16]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 1024.0,
                            "id": "13",
                            "input_names": ["11"],
                            "input_shapes": [[16]],
                            "op_type": "MatMul",
                            "output_names": ["13"],
                            "output_shapes": [[32]],
                            "params": 512,
                            "prunable": True,
                            "prunable_equation_sensitivity": 0.1875,
                            "prunable_params": 512,
                            "prunable_params_zeroed": 0,
                            "weight_name": "22",
                            "weight_shape": [16, 32],
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 32.0,
                            "id": "14",
                            "input_names": ["13"],
                            "input_shapes": [[32]],
                            "op_type": "Add",
                            "output_names": ["14"],
                            "output_shapes": [[32]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 1024.0,
                            "id": "16",
                            "input_names": ["14"],
                            "input_shapes": [[32]],
                            "op_type": "MatMul",
                            "output_names": ["16"],
                            "output_shapes": [[16]],
                            "params": 512,
                            "prunable": True,
                            "prunable_equation_sensitivity": 0.1875,
                            "prunable_params": 512,
                            "prunable_params_zeroed": 0,
                            "weight_name": "23",
                            "weight_shape": [32, 16],
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 16.0,
                            "id": "17",
                            "input_names": ["16"],
                            "input_shapes": [[16]],
                            "op_type": "Add",
                            "output_names": ["17"],
                            "output_shapes": [[16]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 256.0,
                            "id": "19",
                            "input_names": ["17"],
                            "input_shapes": [[16]],
                            "op_type": "MatMul",
                            "output_names": ["19"],
                            "output_shapes": [[8]],
                            "params": 128,
                            "prunable": True,
                            "prunable_equation_sensitivity": 0.375,
                            "prunable_params": 128,
                            "prunable_params_zeroed": 0,
                            "weight_name": "24",
                            "weight_shape": [16, 8],
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 8.0,
                            "id": "output",
                            "input_names": ["19"],
                            "input_shapes": [[8]],
                            "op_type": "Add",
                            "output_names": ["output"],
                            "output_shapes": [[8]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                    ]
                },
            ),
            None,
        ],
        [
            (
                "test_mlp_net",
                MLPNet,
                torch.randn(8),
                {
                    "nodes": [
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 256.0,
                            "id": "8",
                            "input_names": ["input"],
                            "input_shapes": [[8]],
                            "op_type": "MatMul",
                            "output_names": ["8"],
                            "output_shapes": [[16]],
                            "params": 128,
                            "prunable": True,
                            "prunable_equation_sensitivity": 0.375,
                            "prunable_params": 128,
                            "prunable_params_zeroed": 0,
                            "weight_name": "19",
                            "weight_shape": [8, 16],
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 16.0,
                            "id": "9",
                            "input_names": ["8"],
                            "input_shapes": [[16]],
                            "op_type": "Add",
                            "output_names": ["9"],
                            "output_shapes": [[16]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 16.0,
                            "id": "10",
                            "input_names": ["9"],
                            "input_shapes": [[16]],
                            "op_type": "Relu",
                            "output_names": ["10"],
                            "output_shapes": [[16]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 1024.0,
                            "id": "12",
                            "input_names": ["10"],
                            "input_shapes": [[16]],
                            "op_type": "MatMul",
                            "output_names": ["12"],
                            "output_shapes": [[32]],
                            "params": 512,
                            "prunable": True,
                            "prunable_equation_sensitivity": 0.1875,
                            "prunable_params": 512,
                            "prunable_params_zeroed": 0,
                            "weight_name": "20",
                            "weight_shape": [16, 32],
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 32.0,
                            "id": "13",
                            "input_names": ["12"],
                            "input_shapes": [[32]],
                            "op_type": "Add",
                            "output_names": ["13"],
                            "output_shapes": [[32]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 32.0,
                            "id": "14",
                            "input_names": ["13"],
                            "input_shapes": [[32]],
                            "op_type": "Relu",
                            "output_names": ["14"],
                            "output_shapes": [[32]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 4096.0,
                            "id": "16",
                            "input_names": ["14"],
                            "input_shapes": [[32]],
                            "op_type": "MatMul",
                            "output_names": ["16"],
                            "output_shapes": [[64]],
                            "params": 2048,
                            "prunable": True,
                            "prunable_equation_sensitivity": 0.09375,
                            "prunable_params": 2048,
                            "prunable_params_zeroed": 0,
                            "weight_name": "21",
                            "weight_shape": [32, 64],
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 64.0,
                            "id": "17",
                            "input_names": ["16"],
                            "input_shapes": [[64]],
                            "op_type": "Add",
                            "output_names": ["17"],
                            "output_shapes": [[64]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 64.0,
                            "id": "output",
                            "input_names": ["17"],
                            "input_shapes": [[64]],
                            "op_type": "Sigmoid",
                            "output_names": ["output"],
                            "output_shapes": [[64]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                    ]
                },
            ),
            None,
        ],
        [
            (
                "test_conv_net",
                ConvNet,
                torch.randn(16, 3, 3, 3),
                {
                    "nodes": [
                        {
                            "attributes": {
                                "dilations": [1, 1],
                                "group": 1,
                                "kernel_shape": [3, 3],
                                "pads": [1, 1, 1, 1],
                                "strides": [2, 2],
                            },
                            "bias_name": "seq.conv1.bias",
                            "bias_shape": [16],
                            "flops": 27712.0,
                            "id": "7",
                            "input_names": ["input"],
                            "input_shapes": [[16, 3, 3, 3]],
                            "op_type": "Conv",
                            "output_names": ["7"],
                            "output_shapes": [[16, 16, 2, 2]],
                            "params": 448,
                            "prunable": True,
                            "prunable_equation_sensitivity": 7.703703703703703,
                            "prunable_params": 432,
                            "prunable_params_zeroed": 0,
                            "weight_name": "seq.conv1.weight",
                            "weight_shape": [16, 3, 3, 3],
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 1024.0,
                            "id": "8",
                            "input_names": ["7"],
                            "input_shapes": [[16, 16, 2, 2]],
                            "op_type": "Relu",
                            "output_names": ["8"],
                            "output_shapes": [[16, 16, 2, 2]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {
                                "dilations": [1, 1],
                                "group": 1,
                                "kernel_shape": [3, 3],
                                "pads": [1, 1, 1, 1],
                                "strides": [2, 2],
                            },
                            "bias_name": "seq.conv2.bias",
                            "bias_shape": [32],
                            "flops": 73760.0,
                            "id": "9",
                            "input_names": ["8"],
                            "input_shapes": [[16, 16, 2, 2]],
                            "op_type": "Conv",
                            "output_names": ["9"],
                            "output_shapes": [[16, 32, 1, 1]],
                            "params": 4640,
                            "prunable": True,
                            "prunable_equation_sensitivity": 0.6620689655172414,
                            "prunable_params": 4608,
                            "prunable_params_zeroed": 0,
                            "weight_name": "seq.conv2.weight",
                            "weight_shape": [32, 16, 3, 3],
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 512.0,
                            "id": "10",
                            "input_names": ["9"],
                            "input_shapes": [[16, 32, 1, 1]],
                            "op_type": "Relu",
                            "output_names": ["10"],
                            "output_shapes": [[16, 32, 1, 1]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 512.0,
                            "id": "11",
                            "input_names": ["10"],
                            "input_shapes": [[16, 32, 1, 1]],
                            "op_type": "GlobalAveragePool",
                            "output_names": ["11"],
                            "output_shapes": [[16, 32, 1, 1]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": None,
                            "id": "12",
                            "input_names": ["11"],
                            "input_shapes": [[16, 32, 1, 1]],
                            "op_type": "Shape",
                            "output_names": ["12"],
                            "output_shapes": [[4]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {"value": None},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": None,
                            "id": "13",
                            "input_names": [],
                            "input_shapes": None,
                            "op_type": "Constant",
                            "output_names": ["13"],
                            "output_shapes": None,
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {"axis": 0},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": None,
                            "id": "14",
                            "input_names": ["12", "13"],
                            "input_shapes": [[4]],
                            "op_type": "Gather",
                            "output_names": ["14"],
                            "output_shapes": None,
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {"axes": [0]},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": None,
                            "id": "16",
                            "input_names": ["14"],
                            "input_shapes": None,
                            "op_type": "Unsqueeze",
                            "output_names": ["16"],
                            "output_shapes": [[1]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {"axis": 0},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": None,
                            "id": "18",
                            "input_names": ["16"],
                            "input_shapes": [[1]],
                            "op_type": "Concat",
                            "output_names": ["18"],
                            "output_shapes": [[2]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": None,
                            "id": "19",
                            "input_names": ["11", "18"],
                            "input_shapes": [[16, 32, 1, 1], [2]],
                            "op_type": "Reshape",
                            "output_names": ["19"],
                            "output_shapes": [[16, 32]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {"alpha": 1.0, "beta": 1.0, "transB": 1},
                            "bias_name": "mlp.fc.bias",
                            "bias_shape": [10],
                            "flops": 650.0,
                            "id": "20",
                            "input_names": ["19"],
                            "input_shapes": [[16, 32]],
                            "op_type": "Gemm",
                            "output_names": ["20"],
                            "output_shapes": [[16, 10]],
                            "params": 330,
                            "prunable": True,
                            "prunable_equation_sensitivity": 6.516363636363636,
                            "prunable_params": 320,
                            "prunable_params_zeroed": 0,
                            "weight_name": "mlp.fc.weight",
                            "weight_shape": [10, 32],
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 160.0,
                            "id": "output",
                            "input_names": ["20"],
                            "input_shapes": [[16, 10]],
                            "op_type": "Sigmoid",
                            "output_names": ["output"],
                            "output_shapes": [[16, 10]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                    ]
                },
            ),
            (
                "test_conv_net_upgrade",
                ConvNet,
                torch.randn(16, 3, 3, 3),
                {
                    "nodes": [
                        {
                            "attributes": {
                                "dilations": [1, 1],
                                "group": 1,
                                "kernel_shape": [3, 3],
                                "pads": [1, 1, 1, 1],
                                "strides": [2, 2],
                            },
                            "bias_name": "seq.conv1.bias",
                            "bias_shape": [16],
                            "flops": 27712.0,
                            "id": "7",
                            "input_names": ["input"],
                            "input_shapes": [[16, 3, 3, 3]],
                            "op_type": "Conv",
                            "output_names": ["7"],
                            "output_shapes": [[16, 16, 2, 2]],
                            "params": 448,
                            "prunable": True,
                            "prunable_equation_sensitivity": 7.703703703703703,
                            "prunable_params": 432,
                            "prunable_params_zeroed": 0,
                            "weight_name": "seq.conv1.weight",
                            "weight_shape": [16, 3, 3, 3],
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 1024.0,
                            "id": "8",
                            "input_names": ["7"],
                            "input_shapes": [[16, 16, 2, 2]],
                            "op_type": "Relu",
                            "output_names": ["8"],
                            "output_shapes": [[16, 16, 2, 2]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {
                                "dilations": [1, 1],
                                "group": 1,
                                "kernel_shape": [3, 3],
                                "pads": [1, 1, 1, 1],
                                "strides": [2, 2],
                            },
                            "bias_name": "seq.conv2.bias",
                            "bias_shape": [32],
                            "flops": 73760.0,
                            "id": "9",
                            "input_names": ["8"],
                            "input_shapes": [[16, 16, 2, 2]],
                            "op_type": "Conv",
                            "output_names": ["9"],
                            "output_shapes": [[16, 32, 1, 1]],
                            "params": 4640,
                            "prunable": True,
                            "prunable_equation_sensitivity": 0.6620689655172414,
                            "prunable_params": 4608,
                            "prunable_params_zeroed": 0,
                            "weight_name": "seq.conv2.weight",
                            "weight_shape": [32, 16, 3, 3],
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 512.0,
                            "id": "10",
                            "input_names": ["9"],
                            "input_shapes": [[16, 32, 1, 1]],
                            "op_type": "Relu",
                            "output_names": ["10"],
                            "output_shapes": [[16, 32, 1, 1]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 512.0,
                            "id": "11",
                            "input_names": ["10"],
                            "input_shapes": [[16, 32, 1, 1]],
                            "op_type": "GlobalAveragePool",
                            "output_names": ["11"],
                            "output_shapes": [[16, 32, 1, 1]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": None,
                            "id": "17",
                            "input_names": ["11"],
                            "input_shapes": [[16, 32, 1, 1]],
                            "op_type": "Reshape",
                            "output_names": ["17"],
                            "output_shapes": [[16, 32]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                        {
                            "attributes": {"alpha": 1.0, "beta": 1.0, "transB": 1},
                            "bias_name": "mlp.fc.bias",
                            "bias_shape": [10],
                            "flops": 650.0,
                            "id": "18",
                            "input_names": ["17"],
                            "input_shapes": [[16, 32]],
                            "op_type": "Gemm",
                            "output_names": ["18"],
                            "output_shapes": [[16, 10]],
                            "params": 330,
                            "prunable": True,
                            "prunable_equation_sensitivity": 6.516363636363636,
                            "prunable_params": 320,
                            "prunable_params_zeroed": 0,
                            "weight_name": "mlp.fc.weight",
                            "weight_shape": [10, 32],
                        },
                        {
                            "attributes": {},
                            "bias_name": None,
                            "bias_shape": None,
                            "flops": 160.0,
                            "id": "output",
                            "input_names": ["18"],
                            "input_shapes": [[16, 10]],
                            "op_type": "Sigmoid",
                            "output_names": ["output"],
                            "output_shapes": [[16, 10]],
                            "params": 0,
                            "prunable": False,
                            "prunable_equation_sensitivity": None,
                            "prunable_params": -1,
                            "prunable_params_zeroed": 0,
                            "weight_name": None,
                            "weight_shape": None,
                        },
                    ]
                },
            ),
        ],
    ]
)
def analyzer_models(request):
    """
    Make sure the ONNX file for one analyzer test case exists on disk and hand
    back its path along with the expected analyzer outputs

    :param request: pytest request whose ``param`` is a ``(legacy_case,
        upgrade_case)`` pair. ``legacy_case`` is ``(model_name, model_function,
        sample_batch, expected_output)``; ``upgrade_case`` is either falsy or a
        tuple whose last element is the expected output for the upgraded export
    :return: tuple of (path to model.onnx, legacy expected output,
        upgrade expected output or None)
    """
    legacy_case, upgrade_case = request.param

    # only the trailing expected-output entry of the upgrade case is used; the
    # model itself is always built from the legacy case below
    expected_upgrade = None
    if upgrade_case:
        *_, expected_upgrade = upgrade_case

    model_name, model_function, sample_batch, expected_legacy = legacy_case

    model_dir = os.path.join(TEMP_FOLDER, model_name)
    os.makedirs(model_dir, exist_ok=True)
    onnx_path = os.path.join(model_dir, "model.onnx")

    # an already-exported file is reused as-is; export only when it is missing
    if not os.path.exists(onnx_path):
        exporter = ModuleExporter(model_function(), model_dir)
        exporter.export_onnx(sample_batch=sample_batch)

    return os.path.expanduser(onnx_path), expected_legacy, expected_upgrade


class OnnxRepoModelFixture(NamedTuple):
    """
    Result bundle yielded by the onnx_repo_models fixture
    """

    # local path of the model's ONNX file
    model_path: str
    # short identifier for the model (e.g. "resnet50")
    model_name: str
    # path whose text contains "sample-inputs"; onnx_repo_models stores None
    # here when no downloaded data path matches
    input_paths: str
    # path whose text contains "sample-outputs"; onnx_repo_models stores None
    # here when no downloaded data path matches
    output_paths: str


@pytest.fixture(
    scope="session",
    params=[
        (
            {
                "domain": "cv",
                "sub_domain": "classification",
                "architecture": "resnet_v1",
                "sub_architecture": "50",
                "framework": "pytorch",
                "repo": "sparseml",
                "dataset": "imagenet",
                "training_scheme": None,
                "sparse_name": "base",
                "sparse_category": "none",
                "sparse_target": None,
            },
            "resnet50",
        ),
        (
            {
                "domain": "cv",
                "sub_domain": "classification",
                "architecture": "mobilenet_v1",
                "sub_architecture": "1.0",
                "framework": "pytorch",
                "repo": "sparseml",
                "dataset": "imagenet",
                "training_scheme": None,
                "sparse_name": "base",
                "sparse_category": "none",
                "sparse_target": None,
            },
            "mobilenet",
        ),
    ],
)
def onnx_repo_models(request) -> OnnxRepoModelFixture:
    """
    Load a SparseZoo model and locate its sample input / output data

    :param request: pytest request whose ``param`` is a pair of
        (keyword arguments for ``Zoo.load_model``, short model name)
    :return: OnnxRepoModelFixture with the ONNX file's ``downloaded_path()``,
        the short name, and the data paths containing "sample-inputs" /
        "sample-outputs"; either of the last two is None when nothing matches
    """
    load_kwargs, short_name = request.param
    zoo_model = Zoo.load_model(**load_kwargs)
    onnx_path = zoo_model.onnx_file.downloaded_path()

    # resolve every data file in order and classify it by marker text in its
    # path; a later match for the same marker replaces an earlier one
    sample_inputs = sample_outputs = None
    for data_file in zoo_model.data.values():
        local_path = data_file.downloaded_path()
        if "sample-inputs" in local_path:
            sample_inputs = local_path
        elif "sample-outputs" in local_path:
            sample_outputs = local_path

    return OnnxRepoModelFixture(
        model_path=onnx_path,
        model_name=short_name,
        input_paths=sample_inputs,
        output_paths=sample_outputs,
    )
