import onnx
import onnx.checker
import onnx.compose

if __name__ == '__main__':
    # Merge the "flow" model in front of the MNIST classifier by wiring one of
    # the flow model's outputs into the classifier's input, validate the
    # combined model, and only then write it to disk.
    model_dir = "./classifiers/ir_version_7"

    flow = onnx.load(f"{model_dir}/matmuls_replaced_ir_version_7_flow.onnx")
    mnist_model = onnx.load(f"{model_dir}/MnistSimpleClassifier_6_relus.onnx")

    # io_map connects the flow model's tensor "98" to the classifier's input
    # "onnx::MatMul_0"; both names are given as they appear in the original
    # (un-prefixed) models. prefix1/prefix2 are prepended to the names of each
    # graph to avoid collisions between the two models; they are concatenated
    # with no separator (e.g. "flow98"), so downstream code that looks up
    # tensors by name must use that exact form.
    # NOTE(review): merge_models expects both models to share IR version and
    # opset imports -- presumably why both come from ir_version_7; confirm.
    combined_model = onnx.compose.merge_models(
        flow, mnist_model,
        io_map=[("98", "onnx::MatMul_0")],
        prefix1="flow",
        prefix2="classifier",
    )

    # Validate before saving so that a model failing the checker is never
    # written out as the merged artifact.
    onnx.checker.check_model(model=combined_model, full_check=True)
    onnx.save(combined_model, f"{model_dir}/MnistSimpleClassifier_6_relus_merged.onnx")

