{
  "metadata": {
    "total": 217,
    "pass": 209,
    "fail": 8
  },
  "results": [
    {
      "name": "100_HingeLoss",
      "status": "pass",
      "avg_ms": 3.829,
      "std_ms": 1.1613,
      "output_shape": "((), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "10_3D_tensor_matrix_multiplication",
      "status": "pass",
      "avg_ms": 0.2428,
      "std_ms": 0.0033,
      "output_shape": "((16, 1024, 768), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "11_4D_tensor_matrix_multiplication",
      "status": "pass",
      "avg_ms": 3.4623,
      "std_ms": 0.4807,
      "output_shape": "((8, 256, 512, 768), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "12_Matmul_with_diagonal_matrices_",
      "status": "pass",
      "avg_ms": 0.2011,
      "std_ms": 0.003,
      "output_shape": "((4096, 4096), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "13_Matmul_for_symmetric_matrices",
      "status": "pass",
      "avg_ms": 0.2834,
      "std_ms": 0.0033,
      "output_shape": "((4096, 4096), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "14_Matmul_for_upper_triangular_matrices",
      "status": "pass",
      "avg_ms": 0.2815,
      "std_ms": 0.0076,
      "output_shape": "((4096, 4096), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "15_Matmul_for_lower_triangular_matrices",
      "status": "pass",
      "avg_ms": 0.2778,
      "std_ms": 0.0023,
      "output_shape": "((4096, 4096), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "16_Matmul_with_transposed_A",
      "status": "pass",
      "avg_ms": 0.284,
      "std_ms": 0.0033,
      "output_shape": "((2048, 4096), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "17_Matmul_with_transposed_B",
      "status": "pass",
      "avg_ms": 0.2977,
      "std_ms": 0.0053,
      "output_shape": "((2048, 4096), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "18_Matmul_with_transposed_both",
      "status": "pass",
      "avg_ms": 0.3069,
      "std_ms": 0.002,
      "output_shape": "((2048, 4096), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "19_ReLU",
      "status": "pass",
      "avg_ms": 10.9979,
      "std_ms": 0.03,
      "output_shape": "((4096, 393216), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "1_Square_matrix_multiplication_",
      "status": "pass",
      "avg_ms": 0.2776,
      "std_ms": 0.0027,
      "output_shape": "((4096, 4096), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "20_LeakyReLU",
      "status": "pass",
      "avg_ms": 10.9899,
      "std_ms": 0.0701,
      "output_shape": "((4096, 393216), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "21_Sigmoid",
      "status": "pass",
      "avg_ms": 10.9886,
      "std_ms": 0.0465,
      "output_shape": "((4096, 393216), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "22_Tanh",
      "status": "pass",
      "avg_ms": 10.9721,
      "std_ms": 0.0248,
      "output_shape": "((4096, 393216), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "23_Softmax",
      "status": "pass",
      "avg_ms": 22.7029,
      "std_ms": 0.3779,
      "output_shape": "((4096, 393216), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "24_LogSoftmax",
      "status": "pass",
      "avg_ms": 22.6984,
      "std_ms": 0.4055,
      "output_shape": "((4096, 393216), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "25_Swish",
      "status": "pass",
      "avg_ms": 10.9791,
      "std_ms": 0.1617,
      "output_shape": "((4096, 393216), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "26_GELU_",
      "status": "pass",
      "avg_ms": 11.0099,
      "std_ms": 0.0918,
      "output_shape": "((4096, 393216), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "27_SELU_",
      "status": "pass",
      "avg_ms": 11.0052,
      "std_ms": 0.3134,
      "output_shape": "((4096, 393216), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "28_HardSigmoid",
      "status": "pass",
      "avg_ms": 11.0339,
      "std_ms": 0.4841,
      "output_shape": "((4096, 393216), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "29_Softplus",
      "status": "pass",
      "avg_ms": 10.9915,
      "std_ms": 0.0337,
      "output_shape": "((4096, 393216), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "2_Standard_matrix_multiplication_",
      "status": "pass",
      "avg_ms": 0.2822,
      "std_ms": 0.0035,
      "output_shape": "((2048, 4096), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "30_Softsign",
      "status": "pass",
      "avg_ms": 11.0417,
      "std_ms": 0.3301,
      "output_shape": "((4096, 393216), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "31_ELU",
      "status": "pass",
      "avg_ms": 10.993,
      "std_ms": 0.0795,
      "output_shape": "((4096, 393216), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "32_HardTanh",
      "status": "pass",
      "avg_ms": 10.9783,
      "std_ms": 0.0394,
      "output_shape": "((4096, 393216), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "33_BatchNorm",
      "status": "pass",
      "avg_ms": 7.3546,
      "std_ms": 0.4241,
      "output_shape": "((64, 64, 512, 512), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "34_InstanceNorm",
      "status": "pass",
      "avg_ms": 0.9784,
      "std_ms": 0.1988,
      "output_shape": "((16, 64, 256, 256), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "35_GroupNorm_",
      "status": "pass",
      "avg_ms": 0.9726,
      "std_ms": 0.1809,
      "output_shape": "((16, 64, 256, 256), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "36_RMSNorm_",
      "status": "pass",
      "avg_ms": 19.9812,
      "std_ms": 0.4135,
      "output_shape": "((112, 64, 512, 512), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "37_FrobeniusNorm_",
      "status": "pass",
      "avg_ms": 19.6123,
      "std_ms": 0.6023,
      "output_shape": "((112, 64, 512, 512), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "38_L1Norm_",
      "status": "pass",
      "avg_ms": 0.3948,
      "std_ms": 0.013,
      "output_shape": "((4096, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "39_L2Norm_",
      "status": "pass",
      "avg_ms": 0.3843,
      "std_ms": 0.0065,
      "output_shape": "((4096, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "3_Batched_matrix_multiplication",
      "status": "pass",
      "avg_ms": 1.6277,
      "std_ms": 0.298,
      "output_shape": "((128, 512, 2048), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "40_LayerNorm",
      "status": "pass",
      "avg_ms": 0.9804,
      "std_ms": 0.2023,
      "output_shape": "((16, 64, 256, 256), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "41_Max_Pooling_1D",
      "status": "pass",
      "avg_ms": 0.2892,
      "std_ms": 0.0026,
      "output_shape": "((16, 64, 8179), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "42_Max_Pooling_2D",
      "status": "pass",
      "avg_ms": 18.2093,
      "std_ms": 0.0184,
      "output_shape": "((32, 64, 511, 511), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "43_Max_Pooling_3D",
      "status": "pass",
      "avg_ms": 125.6242,
      "std_ms": 0.0302,
      "output_shape": "((16, 32, 62, 62, 62), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "44_Average_Pooling_1D",
      "status": "pass",
      "avg_ms": 7.2843,
      "std_ms": 0.3448,
      "output_shape": "((64, 128, 65537), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "45_Average_Pooling_2D",
      "status": "pass",
      "avg_ms": 0.4564,
      "std_ms": 0.0037,
      "output_shape": "((4, 32, 46, 46), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "46_Average_Pooling_3D",
      "status": "pass",
      "avg_ms": 24.646,
      "std_ms": 0.021,
      "output_shape": "((16, 32, 64, 64, 128), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "47_Sum_reduction_over_a_dimension",
      "status": "pass",
      "avg_ms": 7.6497,
      "std_ms": 0.8977,
      "output_shape": "((128, 1, 4095), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "48_Mean_reduction_over_a_dimension",
      "status": "pass",
      "avg_ms": 7.6432,
      "std_ms": 0.9112,
      "output_shape": "((128, 4095), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "49_Max_reduction_over_a_dimension",
      "status": "pass",
      "avg_ms": 7.6445,
      "std_ms": 0.9345,
      "output_shape": "((128, 4095), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "4_Matrix_vector_multiplication_",
      "status": "pass",
      "avg_ms": 7.6173,
      "std_ms": 0.979,
      "output_shape": "((2048, 1), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "50_conv_standard_2D__square_input__square_kernel",
      "status": "fail",
      "error": "transpose requires ndarray or scalar arguments, got <class 'NoneType'> at position 0.",
      "traceback": "py\", line 1197, in transpose\n    a = util.ensure_arraylike(\"transpose\", a)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/util.py\", line 155, in ensure_arraylike\n    check_arraylike(fun_name, *args)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/util.py\", line 181, in check_arraylike\n    raise TypeError(msg.format(fun_name, type(arg), pos))\nTypeError: transpose requires ndarray or scalar arguments, got <class 'NoneType'> at position 0.\n",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "51_Argmax_over_a_dimension",
      "status": "pass",
      "avg_ms": 7.6917,
      "std_ms": 0.8107,
      "output_shape": "((128, 4095), dtype('int32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "52_Argmin_over_a_dimension",
      "status": "pass",
      "avg_ms": 7.6694,
      "std_ms": 0.8033,
      "output_shape": "((128, 4095), dtype('int32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "53_Min_reduction_over_a_dimension",
      "status": "pass",
      "avg_ms": 7.6458,
      "std_ms": 0.8978,
      "output_shape": "((128, 4095), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "54_conv_standard_3D__square_input__square_kernel",
      "status": "pass",
      "avg_ms": 2.544,
      "std_ms": 0.0046,
      "output_shape": "((16, 64, 62, 62, 62), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "55_conv_standard_2D__asymmetric_input__square_kernel",
      "status": "fail",
      "error": "'NoneType' object has no attribute 'shape'",
      "traceback": " following exception. Set JAX_TRACEBACK_FILTERING=off to include these.\n\nThe above exception was the direct cause of the following exception:\n\nTraceback (most recent call last):\n  File \"/tmp/pallas_eval/baseline_harness.py\", line 20, in <module>\n    out = jax.jit(model.forward)(*inputs)\n  File \"/tmp/pallas_eval/baselines/55_conv_standard_2D__asymmetric_input__square_kernel.py\", line 41, in forward\n    out = jax.lax.conv_general_dilated(\nAttributeError: 'NoneType' object has no attribute 'shape'\n",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "56_conv_standard_2D__asymmetric_input__asymmetric_kernel",
      "status": "pass",
      "avg_ms": 1.8684,
      "std_ms": 0.0051,
      "output_shape": "((8, 128, 508, 250), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "57_conv_transposed_2D__square_input__square_kernel",
      "status": "pass",
      "avg_ms": 15.0057,
      "std_ms": 0.296,
      "output_shape": "((8, 64, 1026, 1026), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "58_conv_transposed_3D__asymmetric_input__asymmetric_kernel",
      "status": "fail",
      "error": "no JSON output",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "59_conv_standard_3D__asymmetric_input__square_kernel",
      "status": "pass",
      "avg_ms": 12.445,
      "std_ms": 0.3378,
      "output_shape": "((16, 64, 254, 254, 10), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "5_Matrix_scalar_multiplication",
      "status": "pass",
      "avg_ms": 7.3591,
      "std_ms": 0.3047,
      "output_shape": "((65536, 16384), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "60_conv_standard_3D__square_input__asymmetric_kernel",
      "status": "pass",
      "avg_ms": 6.266,
      "std_ms": 0.0076,
      "output_shape": "((16, 64, 62, 60, 58), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "61_conv_transposed_3D__square_input__square_kernel",
      "status": "pass",
      "avg_ms": 5.6004,
      "std_ms": 0.0077,
      "output_shape": "((8, 48, 66, 66, 66), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "62_conv_standard_2D__square_input__asymmetric_kernel",
      "status": "pass",
      "avg_ms": 5.0684,
      "std_ms": 0.1838,
      "output_shape": "((8, 64, 508, 504), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "63_conv_standard_2D__square_input__square_kernel",
      "status": "pass",
      "avg_ms": 15.2761,
      "std_ms": 0.1311,
      "output_shape": "((16, 128, 1022, 1022), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "64_conv_transposed_1D",
      "status": "pass",
      "avg_ms": 5.3677,
      "std_ms": 0.5249,
      "output_shape": "((64, 128, 65538), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "65_conv_transposed_2D__square_input__asymmetric_kernel",
      "status": "pass",
      "avg_ms": 5.6899,
      "std_ms": 0.0063,
      "output_shape": "((8, 64, 514, 518), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "66_conv_standard_3D__asymmetric_input__asymmetric_kernel",
      "status": "pass",
      "avg_ms": 6.632,
      "std_ms": 0.0102,
      "output_shape": "((8, 64, 14, 124, 122), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "67_conv_standard_1D",
      "status": "pass",
      "avg_ms": 4.5141,
      "std_ms": 0.1425,
      "output_shape": "((32, 128, 131070), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "68_conv_transposed_3D__square_input__asymmetric_kernel",
      "status": "fail",
      "error": "no JSON output",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "69_conv_transposed_2D__asymmetric_input__asymmetric_kernel",
      "status": "pass",
      "avg_ms": 2.4527,
      "std_ms": 0.271,
      "output_shape": "((64, 128, 130, 260), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "6_Matmul_with_large_K_dimension_",
      "status": "pass",
      "avg_ms": 0.9789,
      "std_ms": 0.2194,
      "output_shape": "((256, 256), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "70_conv_transposed_3D__asymmetric_input__square_kernel",
      "status": "pass",
      "avg_ms": 15.9056,
      "std_ms": 0.017,
      "output_shape": "((8, 24, 98, 98, 98), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "71_conv_transposed_2D__asymmetric_input__square_kernel",
      "status": "pass",
      "avg_ms": 6.4624,
      "std_ms": 0.3841,
      "output_shape": "((8, 32, 514, 1026), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "72_conv_transposed_3D_asymmetric_input_asymmetric_kernel___strided_padded_grouped_",
      "status": "pass",
      "avg_ms": 5.2747,
      "std_ms": 0.0077,
      "output_shape": "((8, 32, 24, 48, 96), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "73_conv_transposed_3D_asymmetric_input_square_kernel__strided_padded__grouped",
      "status": "pass",
      "avg_ms": 14.1251,
      "std_ms": 0.0897,
      "output_shape": "((4, 32, 63, 127, 255), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "74_conv_transposed_1D_dilated",
      "status": "pass",
      "avg_ms": 7.0345,
      "std_ms": 0.0368,
      "output_shape": "((32, 64, 131084), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "75_conv_transposed_2D_asymmetric_input_asymmetric_kernel_strided__grouped____padded____dilated__",
      "status": "pass",
      "avg_ms": 14.6285,
      "std_ms": 0.2476,
      "output_shape": "((16, 64, 257, 766), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "76_conv_standard_1D_dilated_strided__",
      "status": "pass",
      "avg_ms": 0.8877,
      "std_ms": 0.0036,
      "output_shape": "((16, 128, 21843), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "77_conv_transposed_3D_square_input_square_kernel___padded____dilated____strided__",
      "status": "pass",
      "avg_ms": 7.5943,
      "std_ms": 0.0122,
      "output_shape": "((16, 64, 33, 65, 65), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "78_conv_transposed_2D_asymmetric_input_asymmetric_kernel___padded__",
      "status": "fail",
      "error": "conv_general_dilated lhs feature dimension size divided by feature_group_count must equal the rhs input feature dimension size, but 32 // 1 != 3.",
      "traceback": "ect cause of the following exception:\n\nTraceback (most recent call last):\n  File \"/tmp/pallas_eval/baseline_harness.py\", line 20, in <module>\n    out = jax.jit(model.forward)(*inputs)\n  File \"/tmp/pallas_eval/baselines/78_conv_transposed_2D_asymmetric_input_asymmetric_kernel___padded__.py\", line 44, in forward\n    out = lax.conv_transpose(\nValueError: conv_general_dilated lhs feature dimension size divided by feature_group_count must equal the rhs input feature dimension size, but 32 // 1 != 3.\n",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "79_conv_transposed_1D_asymmetric_input_square_kernel___padded____strided____dilated__",
      "status": "pass",
      "avg_ms": 5.6759,
      "std_ms": 0.2847,
      "output_shape": "((16, 64, 262145), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "7_Matmul_with_small_K_dimension_",
      "status": "pass",
      "avg_ms": 3.2881,
      "std_ms": 0.5103,
      "output_shape": "((32768, 32768), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "80_conv_standard_2D_square_input_asymmetric_kernel___dilated____padded__",
      "status": "pass",
      "avg_ms": 4.683,
      "std_ms": 0.0063,
      "output_shape": "((8, 64, 508, 496), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "81_conv_transposed_2D_asymmetric_input_square_kernel___dilated____padded____strided__",
      "status": "pass",
      "avg_ms": 3.4358,
      "std_ms": 0.1703,
      "output_shape": "((16, 64, 318, 638), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "82_conv_depthwise_2D_square_input_square_kernel",
      "status": "pass",
      "avg_ms": 7.3544,
      "std_ms": 0.4916,
      "output_shape": "((16, 64, 510, 510), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "83_conv_depthwise_2D_square_input_asymmetric_kernel",
      "status": "pass",
      "avg_ms": 4.1634,
      "std_ms": 0.0061,
      "output_shape": "((64, 8, 510, 512), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "84_conv_depthwise_2D_asymmetric_input_square_kernel",
      "status": "pass",
      "avg_ms": 12.2313,
      "std_ms": 0.0169,
      "output_shape": "((64, 128, 254, 510), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "85_conv_depthwise_2D_asymmetric_input_asymmetric_kernel",
      "status": "pass",
      "avg_ms": 1.805,
      "std_ms": 0.0027,
      "output_shape": "((32, 128, 126, 250), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "86_conv_depthwise_separable_2D",
      "status": "pass",
      "avg_ms": 8.3812,
      "std_ms": 0.0609,
      "output_shape": "((16, 128, 512, 512), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "87_conv_pointwise_2D",
      "status": "pass",
      "avg_ms": 3.1693,
      "std_ms": 0.0586,
      "output_shape": "((4, 128, 512, 512), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "88_MinGPTNewGelu",
      "status": "pass",
      "avg_ms": 0.4921,
      "std_ms": 0.0049,
      "output_shape": "((8192, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "89_cumsum",
      "status": "pass",
      "avg_ms": 58.1258,
      "std_ms": 0.2631,
      "output_shape": "((32768, 32768), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "8_Matmul_with_irregular_shapes_",
      "status": "pass",
      "avg_ms": 0.5913,
      "std_ms": 0.0011,
      "output_shape": "((8205, 5921), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "90_cumprod",
      "status": "pass",
      "avg_ms": 58.9747,
      "std_ms": 0.3373,
      "output_shape": "((32768, 32768), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "91_cumsum_reverse",
      "status": "pass",
      "avg_ms": 117.7057,
      "std_ms": 3.6913,
      "output_shape": "((32768, 32768), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "92_cumsum_exclusive",
      "status": "pass",
      "avg_ms": 72.1991,
      "std_ms": 0.0674,
      "output_shape": "((32768, 32768), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "93_masked_cumsum",
      "status": "pass",
      "avg_ms": 65.8848,
      "std_ms": 0.1943,
      "output_shape": "((32768, 32768), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "94_MSELoss",
      "status": "pass",
      "avg_ms": 7.641,
      "std_ms": 0.9434,
      "output_shape": "((), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "95_CrossEntropyLoss",
      "status": "pass",
      "avg_ms": 2.0481,
      "std_ms": 0.0044,
      "output_shape": "((), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "96_HuberLoss",
      "status": "pass",
      "avg_ms": 0.2035,
      "std_ms": 0.0026,
      "output_shape": "((), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "97_ScaledDotProductAttention",
      "status": "pass",
      "avg_ms": 10.3519,
      "std_ms": 0.5035,
      "output_shape": "((32, 32, 512, 1024), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "98_KLDivLoss",
      "status": "pass",
      "avg_ms": 2.0231,
      "std_ms": 0.5565,
      "output_shape": "((), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "99_TripletMarginLoss",
      "status": "pass",
      "avg_ms": 2.8609,
      "std_ms": 0.8194,
      "output_shape": "((), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "9_Tall_skinny_matrix_multiplication_",
      "status": "pass",
      "avg_ms": 3.2858,
      "std_ms": 0.42,
      "output_shape": "((32768, 32768), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level1"
    },
    {
      "name": "100_ConvTranspose3d_Clamp_Min_Divide",
      "status": "pass",
      "avg_ms": 5.1176,
      "std_ms": 0.3862,
      "output_shape": "((16, 128, 47, 95, 95), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "10_ConvTranspose2d_MaxPool_Hardtanh_Mean_Tanh",
      "status": "pass",
      "avg_ms": 9.1399,
      "std_ms": 0.0089,
      "output_shape": "((128, 64, 1, 1), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "11_ConvTranspose2d_BatchNorm_Tanh_MaxPool_GroupNorm",
      "status": "pass",
      "avg_ms": 1.2249,
      "std_ms": 0.1033,
      "output_shape": "((512, 128, 17, 17), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "12_Gemm_Multiply_LeakyReLU",
      "status": "pass",
      "avg_ms": 0.1643,
      "std_ms": 0.0163,
      "output_shape": "((1024, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "13_ConvTranspose3d_Mean_Add_Softmax_Tanh_Scaling",
      "status": "pass",
      "avg_ms": 13.1106,
      "std_ms": 0.0387,
      "output_shape": "((16, 64, 1, 128, 128), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "14_Gemm_Divide_Sum_Scaling",
      "status": "pass",
      "avg_ms": 1.115,
      "std_ms": 0.0039,
      "output_shape": "((1024, 1), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "15_ConvTranspose3d_BatchNorm_Subtract",
      "status": "pass",
      "avg_ms": 8.7207,
      "std_ms": 0.0528,
      "output_shape": "((16, 32, 31, 63, 63), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "16_ConvTranspose2d_Mish_Add_Hardtanh_Scaling",
      "status": "pass",
      "avg_ms": 8.1754,
      "std_ms": 0.0072,
      "output_shape": "((128, 64, 256, 256), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "17_Conv2d_InstanceNorm_Divide",
      "status": "pass",
      "avg_ms": 4.6186,
      "std_ms": 0.1531,
      "output_shape": "((128, 128, 126, 126), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "18_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp",
      "status": "pass",
      "avg_ms": 0.1278,
      "std_ms": 0.0042,
      "output_shape": "((1024, 1), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "19_ConvTranspose2d_GELU_GroupNorm",
      "status": "pass",
      "avg_ms": 12.3441,
      "std_ms": 0.1264,
      "output_shape": "((128, 64, 258, 258), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "1_Conv2D_ReLU_BiasAdd",
      "status": "pass",
      "avg_ms": 2.1877,
      "std_ms": 0.0051,
      "output_shape": "((128, 128, 126, 126), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "20_ConvTranspose3d_Sum_ResidualAdd_Multiply_ResidualAdd",
      "status": "pass",
      "avg_ms": 4.4982,
      "std_ms": 0.3459,
      "output_shape": "((16, 64, 32, 64, 64), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "21_Conv2d_Add_Scale_Sigmoid_GroupNorm",
      "status": "pass",
      "avg_ms": 4.0375,
      "std_ms": 0.2115,
      "output_shape": "((128, 32, 254, 254), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "22_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish",
      "status": "fail",
      "error": "'Model' object has no attribute 'matmul_weight'",
      "traceback": "t JAX_TRACEBACK_FILTERING=off to include these.\n\nThe above exception was the direct cause of the following exception:\n\nTraceback (most recent call last):\n  File \"/tmp/pallas_eval/baseline_harness.py\", line 20, in <module>\n    out = jax.jit(model.forward)(*inputs)\n  File \"/tmp/pallas_eval/baselines/22_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish.py\", line 24, in forward\n    x = jnp.matmul(x, self.matmul_weight.T) + self.matmul_bias\nAttributeError: 'Model' object has no attribute 'matmul_weight'\n",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "23_Conv3d_GroupNorm_Mean",
      "status": "pass",
      "avg_ms": 0.6565,
      "std_ms": 0.0068,
      "output_shape": "((128,), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "24_Conv3d_Min_Softmax",
      "status": "pass",
      "avg_ms": 0.3596,
      "std_ms": 0.0044,
      "output_shape": "((128, 24, 30, 30), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "25_Conv2d_Min_Tanh_Tanh",
      "status": "pass",
      "avg_ms": 1.9988,
      "std_ms": 0.0024,
      "output_shape": "((128, 1, 254, 254), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "26_ConvTranspose3d_Add_HardSwish",
      "status": "pass",
      "avg_ms": 4.0705,
      "std_ms": 0.617,
      "output_shape": "((128, 64, 32, 32, 32), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "27_Conv3d_HardSwish_GroupNorm_Mean",
      "status": "pass",
      "avg_ms": 2.908,
      "std_ms": 0.0054,
      "output_shape": "((1024, 16), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "28_BMM_InstanceNorm_Sum_ResidualAdd_Multiply",
      "status": "pass",
      "avg_ms": 0.177,
      "std_ms": 0.0034,
      "output_shape": "((1024, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "29_Matmul_Mish_Mish",
      "status": "pass",
      "avg_ms": 0.1691,
      "std_ms": 0.0035,
      "output_shape": "((1024, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "2_ConvTranspose2d_BiasAdd_Clamp_Scaling_Clamp_Divide",
      "status": "pass",
      "avg_ms": 8.1416,
      "std_ms": 0.0096,
      "output_shape": "((128, 64, 256, 256), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "30_Gemm_GroupNorm_Hardtanh",
      "status": "pass",
      "avg_ms": 0.1607,
      "std_ms": 0.0039,
      "output_shape": "((1024, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "31_Conv2d_Min_Add_Multiply",
      "status": "pass",
      "avg_ms": 2.1978,
      "std_ms": 0.0065,
      "output_shape": "((128, 128, 126, 126), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "32_Conv2d_Scaling_Min",
      "status": "pass",
      "avg_ms": 3.3657,
      "std_ms": 0.0092,
      "output_shape": "((64, 1, 254, 254), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "33_Gemm_Scale_BatchNorm",
      "status": "pass",
      "avg_ms": 0.1642,
      "std_ms": 0.0031,
      "output_shape": "((1024, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "34_ConvTranspose3d_LayerNorm_GELU_Scaling",
      "status": "pass",
      "avg_ms": 21.6421,
      "std_ms": 0.3036,
      "output_shape": "((32, 64, 32, 64, 64), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "35_Conv2d_Subtract_HardSwish_MaxPool_Mish",
      "status": "pass",
      "avg_ms": 3.056,
      "std_ms": 0.012,
      "output_shape": "((128, 128, 63, 63), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "36_ConvTranspose2d_Min_Sum_GELU_Add",
      "status": "pass",
      "avg_ms": 0.9578,
      "std_ms": 0.169,
      "output_shape": "((16, 1, 1, 256), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "37_Matmul_Swish_Sum_GroupNorm",
      "status": "pass",
      "avg_ms": 5.2208,
      "std_ms": 0.5499,
      "output_shape": "((32768, 4096), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "38_ConvTranspose3d_AvgPool_Clamp_Softmax_Multiply",
      "status": "pass",
      "avg_ms": 17.9207,
      "std_ms": 0.0242,
      "output_shape": "((32, 64, 32, 64, 64), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "39_Gemm_Scale_BatchNorm",
      "status": "pass",
      "avg_ms": 2.3932,
      "std_ms": 0.3847,
      "output_shape": "((16384, 4096), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "3_ConvTranspose3d_Sum_LayerNorm_AvgPool_GELU",
      "status": "pass",
      "avg_ms": 9.2635,
      "std_ms": 0.1041,
      "output_shape": "((32, 64, 16, 32, 32), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "40_Matmul_Scaling_ResidualAdd",
      "status": "pass",
      "avg_ms": 1.2963,
      "std_ms": 0.0043,
      "output_shape": "((16384, 4096), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "41_Gemm_BatchNorm_GELU_ReLU",
      "status": "pass",
      "avg_ms": 2.3889,
      "std_ms": 0.3162,
      "output_shape": "((16384, 4096), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "42_ConvTranspose2d_GlobalAvgPool_BiasAdd_LogSumExp_Sum_Multiply",
      "status": "pass",
      "avg_ms": 3.5431,
      "std_ms": 0.007,
      "output_shape": "((16, 1), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "43_Conv3d_Max_LogSumExp_ReLU",
      "status": "pass",
      "avg_ms": 5.0491,
      "std_ms": 0.0494,
      "output_shape": "((4, 1, 16, 64, 64), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean",
      "status": "pass",
      "avg_ms": 0.3935,
      "std_ms": 0.0031,
      "output_shape": "((16, 128, 1, 1), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "45_Gemm_Sigmoid_LogSumExp",
      "status": "pass",
      "avg_ms": 0.7808,
      "std_ms": 0.0069,
      "output_shape": "((16384,), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "46_Conv2d_Subtract_Tanh_Subtract_AvgPool",
      "status": "pass",
      "avg_ms": 3.0467,
      "std_ms": 0.0126,
      "output_shape": "((128, 128, 63, 63), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "47_Conv3d_Mish_Tanh",
      "status": "pass",
      "avg_ms": 2.6173,
      "std_ms": 0.0042,
      "output_shape": "((16, 64, 30, 62, 62), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "48_Conv3d_Scaling_Tanh_Multiply_Sigmoid",
      "status": "pass",
      "avg_ms": 0.7212,
      "std_ms": 0.0053,
      "output_shape": "((128, 16, 14, 62, 62), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "49_ConvTranspose3d_Softmax_Sigmoid",
      "status": "pass",
      "avg_ms": 5.5015,
      "std_ms": 0.7993,
      "output_shape": "((16, 64, 32, 64, 64), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "4_Conv2d_Mish_Mish",
      "status": "pass",
      "avg_ms": 5.5177,
      "std_ms": 0.0082,
      "output_shape": "((64, 128, 254, 254), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "50_ConvTranspose3d_Scaling_AvgPool_BiasAdd_Scaling",
      "status": "pass",
      "avg_ms": 1.8653,
      "std_ms": 0.4812,
      "output_shape": "((128, 16, 15, 31, 31), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "51_Gemm_Subtract_GlobalAvgPool_LogSumExp_GELU_ResidualAdd",
      "status": "pass",
      "avg_ms": 0.2404,
      "std_ms": 0.0067,
      "output_shape": "((2048, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "52_Conv2d_Activation_BatchNorm",
      "status": "pass",
      "avg_ms": 2.5654,
      "std_ms": 0.274,
      "output_shape": "((64, 128, 126, 126), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "53_Gemm_Scaling_Hardtanh_GELU",
      "status": "pass",
      "avg_ms": 0.2318,
      "std_ms": 0.0023,
      "output_shape": "((2048, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "54_Conv2d_Multiply_LeakyReLU_GELU",
      "status": "pass",
      "avg_ms": 9.0919,
      "std_ms": 0.2205,
      "output_shape": "((64, 64, 254, 254), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "55_Matmul_MaxPool_Sum_Scale",
      "status": "pass",
      "avg_ms": 0.1697,
      "std_ms": 0.004,
      "output_shape": "((128,), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "56_Matmul_Sigmoid_Sum",
      "status": "pass",
      "avg_ms": 0.1197,
      "std_ms": 0.0045,
      "output_shape": "((128, 1), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "57_Conv2d_ReLU_HardSwish",
      "status": "pass",
      "avg_ms": 0.5468,
      "std_ms": 0.0053,
      "output_shape": "((128, 64, 126, 126), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "58_ConvTranspose3d_LogSumExp_HardSwish_Subtract_Clamp",
      "status": "pass",
      "avg_ms": 3.0235,
      "std_ms": 0.5507,
      "output_shape": "((128, 1, 31, 63, 63), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "59_Matmul_Swish_Scaling",
      "status": "pass",
      "avg_ms": 0.1328,
      "std_ms": 0.0041,
      "output_shape": "((128, 32768), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "5_ConvTranspose2d_Subtract_Tanh",
      "status": "pass",
      "avg_ms": 11.1211,
      "std_ms": 0.1363,
      "output_shape": "((32, 64, 513, 513), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "60_ConvTranspose3d_Swish_GroupNorm_HardSwish",
      "status": "pass",
      "avg_ms": 5.1876,
      "std_ms": 0.0107,
      "output_shape": "((128, 16, 31, 63, 63), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "61_ConvTranspose3d_ReLU_GroupNorm",
      "status": "pass",
      "avg_ms": 1.7422,
      "std_ms": 0.0034,
      "output_shape": "((16, 128, 34, 34, 34), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "62_Matmul_GroupNorm_LeakyReLU_Sum",
      "status": "pass",
      "avg_ms": 0.1548,
      "std_ms": 0.0087,
      "output_shape": "((1024, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "63_Gemm_ReLU_Divide",
      "status": "pass",
      "avg_ms": 0.1584,
      "std_ms": 0.0052,
      "output_shape": "((1024, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "64_Gemm_LogSumExp_LeakyReLU_LeakyReLU_GELU_GELU",
      "status": "pass",
      "avg_ms": 0.1318,
      "std_ms": 0.0042,
      "output_shape": "((1024, 1), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "65_Conv2d_AvgPool_Sigmoid_Sum",
      "status": "pass",
      "avg_ms": 9.5015,
      "std_ms": 0.4722,
      "output_shape": "((128,), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "66_Matmul_Dropout_Softmax",
      "status": "pass",
      "avg_ms": 0.1229,
      "std_ms": 0.0059,
      "output_shape": "((128, 16384), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "67_Conv2d_GELU_GlobalAvgPool",
      "status": "pass",
      "avg_ms": 1.4309,
      "std_ms": 0.0048,
      "output_shape": "((128, 64), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "68_Matmul_Min_Subtract",
      "status": "pass",
      "avg_ms": 0.1175,
      "std_ms": 0.0059,
      "output_shape": "((128, 16384), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "69_Conv2d_HardSwish_ReLU",
      "status": "pass",
      "avg_ms": 0.5532,
      "std_ms": 0.0036,
      "output_shape": "((128, 64, 126, 126), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "6_Conv3d_Softmax_MaxPool_MaxPool",
      "status": "pass",
      "avg_ms": 0.5372,
      "std_ms": 0.0043,
      "output_shape": "((128, 16, 3, 7, 7), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "70_Gemm_Sigmoid_Scaling_ResidualAdd",
      "status": "pass",
      "avg_ms": 0.1564,
      "std_ms": 0.0036,
      "output_shape": "((1024, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "71_Conv2d_Divide_LeakyReLU",
      "status": "pass",
      "avg_ms": 0.5485,
      "std_ms": 0.0069,
      "output_shape": "((128, 64, 126, 126), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "72_ConvTranspose3d_BatchNorm_AvgPool_AvgPool",
      "status": "pass",
      "avg_ms": 3.8343,
      "std_ms": 0.8618,
      "output_shape": "((64, 16, 15, 15, 15), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "73_Conv2d_BatchNorm_Scaling",
      "status": "pass",
      "avg_ms": 1.8927,
      "std_ms": 0.3177,
      "output_shape": "((128, 64, 126, 126), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "74_ConvTranspose3d_LeakyReLU_Multiply_LeakyReLU_Max",
      "status": "pass",
      "avg_ms": 3.5487,
      "std_ms": 0.6945,
      "output_shape": "((16, 32, 16, 32, 32), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "75_Gemm_GroupNorm_Min_BiasAdd",
      "status": "pass",
      "avg_ms": 0.2803,
      "std_ms": 0.0043,
      "output_shape": "((1, 8192, 1024, 1), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "76_Gemm_Add_ReLU",
      "status": "pass",
      "avg_ms": 0.1646,
      "std_ms": 0.0051,
      "output_shape": "((1024, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "77_ConvTranspose3d_Scale_BatchNorm_GlobalAvgPool",
      "status": "pass",
      "avg_ms": 2.4087,
      "std_ms": 0.0052,
      "output_shape": "((16, 128, 1, 1, 1), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "78_ConvTranspose3d_Max_Max_Sum",
      "status": "pass",
      "avg_ms": 6.7567,
      "std_ms": 0.005,
      "output_shape": "((16, 1, 10, 10, 10), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "79_Conv3d_Multiply_InstanceNorm_Clamp_Multiply_Max",
      "status": "pass",
      "avg_ms": 0.3917,
      "std_ms": 0.0025,
      "output_shape": "((128, 14, 30, 30), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "7_Conv3d_ReLU_LeakyReLU_GELU_Sigmoid_BiasAdd",
      "status": "pass",
      "avg_ms": 2.6196,
      "std_ms": 0.4465,
      "output_shape": "((64, 32, 30, 62, 62), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "80_Gemm_Max_Subtract_GELU",
      "status": "pass",
      "avg_ms": 0.1621,
      "std_ms": 0.005,
      "output_shape": "((1024, 1), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "81_Gemm_Swish_Divide_Clamp_Tanh_Clamp",
      "status": "pass",
      "avg_ms": 0.1636,
      "std_ms": 0.0057,
      "output_shape": "((1024, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "82_Conv2d_Tanh_Scaling_BiasAdd_Max",
      "status": "pass",
      "avg_ms": 4.0803,
      "std_ms": 0.32,
      "output_shape": "((128, 64, 63, 63), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "83_Conv3d_GroupNorm_Min_Clamp_Dropout",
      "status": "pass",
      "avg_ms": 1.7088,
      "std_ms": 0.1253,
      "output_shape": "((128, 16, 14, 62, 62), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "84_Gemm_BatchNorm_Scaling_Softmax",
      "status": "pass",
      "avg_ms": 0.161,
      "std_ms": 0.0051,
      "output_shape": "((1024, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "85_Conv2d_GroupNorm_Scale_MaxPool_Clamp",
      "status": "pass",
      "avg_ms": 1.3798,
      "std_ms": 0.2952,
      "output_shape": "((128, 64, 31, 31), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "86_Matmul_Divide_GELU",
      "status": "pass",
      "avg_ms": 0.1635,
      "std_ms": 0.0016,
      "output_shape": "((1024, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "87_Conv2d_Subtract_Subtract_Mish",
      "status": "pass",
      "avg_ms": 2.1991,
      "std_ms": 0.0089,
      "output_shape": "((128, 64, 254, 254), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "88_Gemm_GroupNorm_Swish_Multiply_Swish",
      "status": "pass",
      "avg_ms": 0.1604,
      "std_ms": 0.0069,
      "output_shape": "((1024, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "89_ConvTranspose3d_MaxPool_Softmax_Subtract_Swish_Max",
      "status": "pass",
      "avg_ms": 2.6379,
      "std_ms": 0.532,
      "output_shape": "((128, 16, 32, 32), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "8_Conv3d_Divide_Max_GlobalAvgPool_BiasAdd_Sum",
      "status": "pass",
      "avg_ms": 1.2935,
      "std_ms": 0.2914,
      "output_shape": "((128, 1, 1, 1), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "90_Conv3d_LeakyReLU_Sum_Clamp_GELU",
      "status": "pass",
      "avg_ms": 2.1238,
      "std_ms": 0.3437,
      "output_shape": "((128, 64, 14, 62, 62), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "91_ConvTranspose2d_Softmax_BiasAdd_Scaling_Sigmoid",
      "status": "pass",
      "avg_ms": 5.8677,
      "std_ms": 0.7464,
      "output_shape": "((128, 128, 129, 129), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "92_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp",
      "status": "pass",
      "avg_ms": 1.946,
      "std_ms": 0.448,
      "output_shape": "((128, 1, 126, 126), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "93_ConvTranspose2d_Add_Min_GELU_Multiply",
      "status": "pass",
      "avg_ms": 1.2504,
      "std_ms": 0.0014,
      "output_shape": "((128, 128, 130, 130), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "94_Gemm_BiasAdd_Hardtanh_Mish_GroupNorm",
      "status": "pass",
      "avg_ms": 0.1968,
      "std_ms": 0.0064,
      "output_shape": "((1024, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "95_Matmul_Add_Swish_Tanh_GELU_Hardtanh",
      "status": "pass",
      "avg_ms": 0.1589,
      "std_ms": 0.0041,
      "output_shape": "((1024, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "96_ConvTranspose3d_Multiply_Max_GlobalAvgPool_Clamp",
      "status": "pass",
      "avg_ms": 1.7663,
      "std_ms": 0.4064,
      "output_shape": "((128, 16, 1, 1, 1), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "97_Matmul_BatchNorm_BiasAdd_Divide_Swish",
      "status": "pass",
      "avg_ms": 0.1593,
      "std_ms": 0.0043,
      "output_shape": "((1024, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "98_Matmul_AvgPool_GELU_Scale_Max",
      "status": "pass",
      "avg_ms": 0.1547,
      "std_ms": 0.0029,
      "output_shape": "((1024,), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "99_Matmul_GELU_Softmax",
      "status": "pass",
      "avg_ms": 0.1606,
      "std_ms": 0.0046,
      "output_shape": "((1024, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "9_Matmul_Subtract_Multiply_ReLU",
      "status": "pass",
      "avg_ms": 0.1693,
      "std_ms": 0.0042,
      "output_shape": "((1024, 8192), dtype('float32'))",
      "suite": "jaxkernelbench",
      "level": "level2"
    },
    {
      "name": "cross_entropy",
      "status": "pass",
      "avg_ms": 7.6771,
      "std_ms": 0.0124,
      "output_shape": "((), dtype('float32'))",
      "suite": "priority_kernels",
      "level": null
    },
    {
      "name": "flash_attention",
      "status": "pass",
      "avg_ms": 1.4583,
      "std_ms": 0.2049,
      "output_shape": "((1, 64, 2048, 128), dtype(bfloat16))",
      "suite": "priority_kernels",
      "level": null
    },
    {
      "name": "flex_attention",
      "status": "pass",
      "avg_ms": 2.866,
      "std_ms": 0.4981,
      "output_shape": "((1, 64, 2048, 128), dtype(bfloat16))",
      "suite": "priority_kernels",
      "level": null
    },
    {
      "name": "gemm",
      "status": "pass",
      "avg_ms": 5.4526,
      "std_ms": 0.0059,
      "output_shape": "((8192, 28672), dtype(bfloat16))",
      "suite": "priority_kernels",
      "level": null
    },
    {
      "name": "gqa_attention",
      "status": "pass",
      "avg_ms": 3.2546,
      "std_ms": 0.0092,
      "output_shape": "((1, 2048, 128, 128), dtype(bfloat16))",
      "suite": "priority_kernels",
      "level": null
    },
    {
      "name": "mamba2_ssd",
      "status": "pass",
      "avg_ms": 1.9016,
      "std_ms": 0.5428,
      "output_shape": "((1, 64, 2048, 64), dtype(bfloat16))",
      "suite": "priority_kernels",
      "level": null
    },
    {
      "name": "megablox_gmm",
      "status": "fail",
      "error": "Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(0, Traced<int32[]>with<DynamicJaxprTrace>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).",
      "traceback": "r_update\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/indexing.py\", line 929, in index_to_gather\n    raise IndexError(msg)\nIndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(0, Traced<int32[]>with<DynamicJaxprTrace>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).\n",
      "suite": "priority_kernels",
      "level": null
    },
    {
      "name": "mla_attention",
      "status": "pass",
      "avg_ms": 4.4812,
      "std_ms": 0.0832,
      "output_shape": "((1, 2048, 7168), dtype(bfloat16))",
      "suite": "priority_kernels",
      "level": null
    },
    {
      "name": "paged_attention",
      "status": "pass",
      "avg_ms": 1.9649,
      "std_ms": 0.0037,
      "output_shape": "((32, 64, 128), dtype(bfloat16))",
      "suite": "priority_kernels",
      "level": null
    },
    {
      "name": "ragged_dot",
      "status": "pass",
      "avg_ms": 1.3652,
      "std_ms": 0.0045,
      "output_shape": "((8, 1024, 14336), dtype(bfloat16))",
      "suite": "priority_kernels",
      "level": null
    },
    {
      "name": "ragged_paged_attention",
      "status": "fail",
      "error": "The __index__() method was called on traced array with shape int32[]\nThe error occurred while tracing the function workload at /tmp/pallas_eval/baselines/ragged_paged_attention.py:60 for jit. This concrete value was not available in Python because it depends on the value of the argument num_seqs.\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerIntegerConversionError",
      "traceback": "aged_attention.py\", line 74, in workload\n    for i in range(num_seqs[0]):\njax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[]\nThe error occurred while tracing the function workload at /tmp/pallas_eval/baselines/ragged_paged_attention.py:60 for jit. This concrete value was not available in Python because it depends on the value of the argument num_seqs.\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerIntegerConversionError\n",
      "suite": "priority_kernels",
      "level": null
    },
    {
      "name": "retnet_retention",
      "status": "pass",
      "avg_ms": 0.5177,
      "std_ms": 0.0029,
      "output_shape": "((1, 16, 2048, 256), dtype(bfloat16))",
      "suite": "priority_kernels",
      "level": null
    },
    {
      "name": "rms_norm",
      "status": "pass",
      "avg_ms": 0.178,
      "std_ms": 0.0047,
      "output_shape": "((1, 2048, 8192), dtype(bfloat16))",
      "suite": "priority_kernels",
      "level": null
    },
    {
      "name": "sparse_attention",
      "status": "pass",
      "avg_ms": 1.4815,
      "std_ms": 0.2145,
      "output_shape": "((64, 2048, 128), dtype(bfloat16))",
      "suite": "priority_kernels",
      "level": null
    },
    {
      "name": "sparse_moe",
      "status": "pass",
      "avg_ms": 8.26,
      "std_ms": 0.0094,
      "output_shape": "((1, 2048, 4096), dtype('float32'))",
      "suite": "priority_kernels",
      "level": null
    },
    {
      "name": "swiglu_mlp",
      "status": "pass",
      "avg_ms": 4.0585,
      "std_ms": 0.0075,
      "output_shape": "((1, 2048, 8192), dtype(bfloat16))",
      "suite": "priority_kernels",
      "level": null
    },
    {
      "name": "triangle_multiplication",
      "status": "pass",
      "avg_ms": 1.3116,
      "std_ms": 0.0066,
      "output_shape": "((768, 768, 64), dtype(bfloat16))",
      "suite": "priority_kernels",
      "level": null
    }
  ]
}