{
  "metadata": {
    "total": 217,
    "correct": 68,
    "faster": 3,
    "errors": 139
  },
  "results": [
    {
      "name": "100_HingeLoss",
      "correct": true,
      "max_diff": 2e-06,
      "correctness_reason": "ok",
      "original_ms": 3.9405,
      "generated_ms": 177.4364,
      "speedup": 0.022,
      "original_std_ms": 1.2211,
      "generated_std_ms": 0.0166,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 14.5
    },
    {
      "name": "10_3D_tensor_matrix_multiplication",
      "correct": true,
      "max_diff": 0.000183,
      "correctness_reason": "ok",
      "original_ms": 0.2426,
      "generated_ms": 1.0183,
      "speedup": 0.238,
      "original_std_ms": 0.0041,
      "generated_std_ms": 0.0081,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 11.2
    },
    {
      "name": "11_4D_tensor_matrix_multiplication",
      "status": "error",
      "error": "Attempted boolean conversion of traced array with shape bool[].\nThe error occurred while tracing the function body at /tmp/pallas_eval/generated/11_4D_tensor_matrix_multiplication_gpt53.py:27 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError",
      "traceback": "slice\n    start, step, size = core.canonicalize_slice(slc, size)\njax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].\nThe error occurred while tracing the function body at /tmp/pallas_eval/generated/11_4D_tensor_matrix_multiplication_gpt53.py:27 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 10.6
    },
    {
      "name": "12_Matmul_with_diagonal_matrices_",
      "status": "error",
      "error": "repeat() missing 1 required positional argument: 'axis'",
      "traceback": "las/pallas_call.py\", line 1210, in _trace_kernel_to_jaxpr\n    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/primitives.py\", line 874, in wrap_with_transforms\n    return f(*new_args)\n  File \"/tmp/pallas_eval/generated/12_Matmul_with_diagonal_matrices__gpt53.py\", line 26, in kernel\n    a_broadcast = pltpu.repeat(a, (1, b.shape[1]))\nTypeError: repeat() missing 1 required positional argument: 'axis'\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 6.8
    },
    {
      "name": "13_Matmul_for_symmetric_matrices",
      "status": "error",
      "error": "Attempted boolean conversion of traced array with shape bool[].\nThe error occurred while tracing the function body at /tmp/pallas_eval/generated/13_Matmul_for_symmetric_matrices_gpt53.py:20 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError",
      "traceback": "m_slice\n    start, step, size = core.canonicalize_slice(slc, size)\njax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].\nThe error occurred while tracing the function body at /tmp/pallas_eval/generated/13_Matmul_for_symmetric_matrices_gpt53.py:20 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 8.9
    },
    {
      "name": "14_Matmul_for_upper_triangular_matrices",
      "status": "error",
      "error": "Attempted boolean conversion of traced array with shape bool[].\nThe error occurred while tracing the function body at /tmp/pallas_eval/generated/14_Matmul_for_upper_triangular_matrices_gpt53.py:26 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError",
      "traceback": "\n    start, step, size = core.canonicalize_slice(slc, size)\njax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].\nThe error occurred while tracing the function body at /tmp/pallas_eval/generated/14_Matmul_for_upper_triangular_matrices_gpt53.py:26 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 8.7
    },
    {
      "name": "15_Matmul_for_lower_triangular_matrices",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space vmem. Used 195.78M of 128.00M vmem. Exceeded vmem capacity by 67.78M.\n\nProgram vmem requirement 195.78M:\n    scoped          195.78M\n\n  Largest program allocations in vmem:\n\n  1. Size: 64.00M\n     Operator: op_name=\"jit(forward)/jit(main)/pallas_call\" source_file=\"/tmp/pallas_eval/generated/15_Matmul_for_lower_triangular_matrices_gpt53.py\" source_line=30\n     Shape: u8[67108864]{0}\n     Unpadded size: 64.00M\n     XLA label: main.1 = custom-call(Arg_0.1, Arg_1.2), custom_call_target=\"tpu_custom_call\", operand_layout_constraints={f32[4096,4096]{1,0}, f32[4096,4096]{1,0}}\n     Allocation type: scoped\n     ==========================\n\n  2. Size: 64.00M\n     Operator: op_name=\"jit(forward)/jit(main)/pallas_call\" source_file=\"/tmp/pallas_eval/generated/15_Matmul_for_lower_triangular_matrices_gpt53.py\" source_line=30\n     Shape: u8[67108864]{0}\n     Unpadded size: 64.00M\n     XLA label: main.1 = custom-call(Arg_0.1, Arg_1.2), custom_call_target=\"tpu_custom_call\", operand_layout_constraints={f32[4096,4096]{1,0}, f32[4096,4096]{1,0}}\n     Allocation type: scoped\n     ==========================\n\n  3. Size: 64.00M\n     Operator: op_name=\"jit(forward)/jit(main)/pallas_call\" source_file=\"/tmp/pallas_eval/generated/15_Matmul_for_lower_triangular_matrices_gpt53.py\" source_line=30\n     Shape: u8[67108864]{0}\n     Unpadded size: 64.00M\n     XLA label: main.1 = custom-call(Arg_0.1, Arg_1.2), custom_call_target=\"tpu_custom_call\", operand_layout_constraints={f32[4096,4096]{1,0}, f32[4096,4096]{1,0}}\n     Allocation type: scoped\n     ==========================\n\n  4. Size: 3.78M\n     XLA label: register allocator spill slots call depth 2\n     Allocation type: scoped\n     ==========================\n\n",
      "traceback": "val/generated/15_Matmul_for_lower_triangular_matrices_gpt53.py\" source_line=30\n     Shape: u8[67108864]{0}\n     Unpadded size: 64.00M\n     XLA label: main.1 = custom-call(Arg_0.1, Arg_1.2), custom_call_target=\"tpu_custom_call\", operand_layout_constraints={f32[4096,4096]{1,0}, f32[4096,4096]{1,0}}\n     Allocation type: scoped\n     ==========================\n\n  4. Size: 3.78M\n     XLA label: register allocator spill slots call depth 2\n     Allocation type: scoped\n     ==========================\n\n\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 55.1
    },
    {
      "name": "16_Matmul_with_transposed_A",
      "status": "error",
      "error": "Unable to initialize backend 'tpu': UNKNOWN: TPU initialization failed: open(/dev/vfio/0): Device or resource busy: Device or resource busy; Couldn't open iommu group /dev/vfio/0 (set JAX_PLATFORMS='' to automatically choose an available backend)",
      "traceback": "-packages/jax/_src/core.py\", line 1066, in process_primitive\n    return primitive.impl(*args, **params)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/dispatch.py\", line 91, in apply_primitive\n    outs = fun(*args)\nRuntimeError: Unable to initialize backend 'tpu': UNKNOWN: TPU initialization failed: open(/dev/vfio/0): Device or resource busy: Device or resource busy; Couldn't open iommu group /dev/vfio/0 (set JAX_PLATFORMS='' to automatically choose an available backend)\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 1.1
    },
    {
      "name": "17_Matmul_with_transposed_B",
      "status": "error",
      "error": "Unable to initialize backend 'tpu': UNKNOWN: TPU initialization failed: open(/dev/vfio/0): Device or resource busy: Device or resource busy; Couldn't open iommu group /dev/vfio/0 (set JAX_PLATFORMS='' to automatically choose an available backend)",
      "traceback": "-packages/jax/_src/core.py\", line 1066, in process_primitive\n    return primitive.impl(*args, **params)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/dispatch.py\", line 91, in apply_primitive\n    outs = fun(*args)\nRuntimeError: Unable to initialize backend 'tpu': UNKNOWN: TPU initialization failed: open(/dev/vfio/0): Device or resource busy: Device or resource busy; Couldn't open iommu group /dev/vfio/0 (set JAX_PLATFORMS='' to automatically choose an available backend)\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 1.1
    },
    {
      "name": "18_Matmul_with_transposed_both",
      "status": "error",
      "error": "Unable to initialize backend 'tpu': ABORTED: The TPU is already in use by another process probably owned by another user. Run \"$ sudo lsof -w /dev/vfio/0\" to figure out which process is using the TPU. If you still get this message, run \"$ sudo rm /tmp/libtpu_lockfile\". (set JAX_PLATFORMS='' to automatically choose an available backend)",
      "traceback": "s, **params)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/dispatch.py\", line 91, in apply_primitive\n    outs = fun(*args)\nRuntimeError: Unable to initialize backend 'tpu': ABORTED: The TPU is already in use by another process probably owned by another user. Run \"$ sudo lsof -w /dev/vfio/0\" to figure out which process is using the TPU. If you still get this message, run \"$ sudo rm /tmp/libtpu_lockfile\". (set JAX_PLATFORMS='' to automatically choose an available backend)\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 1.4
    },
    {
      "name": "19_ReLU",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)",
      "traceback": "s_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 102, in eval_jaxkernelbench\n    gen_times, _ = benchmark_fn(gen_model.forward, inputs)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 41, in benchmark_fn\n    out = jitted(*inputs)\nValueError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 23.1
    },
    {
      "name": "1_Square_matrix_multiplication_",
      "correct": true,
      "max_diff": 0.000366,
      "correctness_reason": "ok",
      "original_ms": 0.2825,
      "generated_ms": 2.1197,
      "speedup": 0.133,
      "original_std_ms": 0.014,
      "generated_std_ms": 0.6487,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 10.4
    },
    {
      "name": "20_LeakyReLU",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)",
      "traceback": "s_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 102, in eval_jaxkernelbench\n    gen_times, _ = benchmark_fn(gen_model.forward, inputs)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 41, in benchmark_fn\n    out = jitted(*inputs)\nValueError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 26.4
    },
    {
      "name": "21_Sigmoid",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)",
      "traceback": "s_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 102, in eval_jaxkernelbench\n    gen_times, _ = benchmark_fn(gen_model.forward, inputs)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 41, in benchmark_fn\n    out = jitted(*inputs)\nValueError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 18.6
    },
    {
      "name": "22_Tanh",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)",
      "traceback": "s_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 102, in eval_jaxkernelbench\n    gen_times, _ = benchmark_fn(gen_model.forward, inputs)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 41, in benchmark_fn\n    out = jitted(*inputs)\nValueError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 17.4
    },
    {
      "name": "23_Softmax",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call softmax_kernel at /tmp/pallas_eval/generated/23_Softmax_gpt53.py:13 has block shape (Blocked(block_size=1), Blocked(block_size=393216)), array shape (4096, 393216), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "ns of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call softmax_kernel at /tmp/pallas_eval/generated/23_Softmax_gpt53.py:13 has block shape (Blocked(block_size=1), Blocked(block_size=393216)), array shape (4096, 393216), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 9.1
    },
    {
      "name": "24_LogSoftmax",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel_fn at /tmp/pallas_eval/generated/24_LogSoftmax_gpt53.py:14 has block shape (Blocked(block_size=1), Blocked(block_size=393216)), array shape (4096, 393216), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "ions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel_fn at /tmp/pallas_eval/generated/24_LogSoftmax_gpt53.py:14 has block shape (Blocked(block_size=1), Blocked(block_size=393216)), array shape (4096, 393216), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 8.9
    },
    {
      "name": "25_Swish",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)",
      "traceback": "s_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 102, in eval_jaxkernelbench\n    gen_times, _ = benchmark_fn(gen_model.forward, inputs)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 41, in benchmark_fn\n    out = jitted(*inputs)\nValueError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 17.8
    },
    {
      "name": "26_GELU_",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)",
      "traceback": "s_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 102, in eval_jaxkernelbench\n    gen_times, _ = benchmark_fn(gen_model.forward, inputs)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 41, in benchmark_fn\n    out = jitted(*inputs)\nValueError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 24.4
    },
    {
      "name": "27_SELU_",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)",
      "traceback": "s_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 102, in eval_jaxkernelbench\n    gen_times, _ = benchmark_fn(gen_model.forward, inputs)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 41, in benchmark_fn\n    out = jitted(*inputs)\nValueError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 18.1
    },
    {
      "name": "28_HardSigmoid",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)",
      "traceback": "s_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 102, in eval_jaxkernelbench\n    gen_times, _ = benchmark_fn(gen_model.forward, inputs)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 41, in benchmark_fn\n    out = jitted(*inputs)\nValueError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 17.6
    },
    {
      "name": "29_Softplus",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)",
      "traceback": "s_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 102, in eval_jaxkernelbench\n    gen_times, _ = benchmark_fn(gen_model.forward, inputs)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 41, in benchmark_fn\n    out = jitted(*inputs)\nValueError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 18.0
    },
    {
      "name": "2_Standard_matrix_multiplication_",
      "correct": true,
      "max_diff": 0.000854,
      "correctness_reason": "ok",
      "original_ms": 0.2988,
      "generated_ms": 2.0204,
      "speedup": 0.148,
      "original_std_ms": 0.0036,
      "generated_std_ms": 0.6486,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 11.5
    },
    {
      "name": "30_Softsign",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)",
      "traceback": "s_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 102, in eval_jaxkernelbench\n    gen_times, _ = benchmark_fn(gen_model.forward, inputs)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 41, in benchmark_fn\n    out = jitted(*inputs)\nValueError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 19.4
    },
    {
      "name": "31_ELU",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)",
      "traceback": "s_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 102, in eval_jaxkernelbench\n    gen_times, _ = benchmark_fn(gen_model.forward, inputs)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 41, in benchmark_fn\n    out = jitted(*inputs)\nValueError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 18.1
    },
    {
      "name": "32_HardTanh",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)",
      "traceback": "s_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 102, in eval_jaxkernelbench\n    gen_times, _ = benchmark_fn(gen_model.forward, inputs)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 41, in benchmark_fn\n    out = jitted(*inputs)\nValueError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 6.00G. That was not possible. There are 1.25G free.; (0x0x0_HBM0)\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 23.4
    },
    {
      "name": "33_BatchNorm",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Error loading program: Attempting to reserve 16.00G at the bottom of memory. That was not possible. There are 11.25G free, 0B reserved, and 11.25G reservable.",
      "traceback": " eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 102, in eval_jaxkernelbench\n    gen_times, _ = benchmark_fn(gen_model.forward, inputs)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 41, in benchmark_fn\n    out = jitted(*inputs)\njaxlib._jax.XlaRuntimeError: RESOURCE_EXHAUSTED: Error loading program: Attempting to reserve 16.00G at the bottom of memory. That was not possible. There are 11.25G free, 0B reserved, and 11.25G reservable.\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 16.3
    },
    {
      "name": "34_InstanceNorm",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel_fn at /tmp/pallas_eval/generated/34_InstanceNorm_gpt53.py:24 has block shape (Blocked(block_size=1), Blocked(block_size=65536)), array shape (1024, 65536), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "ions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel_fn at /tmp/pallas_eval/generated/34_InstanceNorm_gpt53.py:24 has block shape (Blocked(block_size=1), Blocked(block_size=65536)), array shape (1024, 65536), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 9.4
    },
    {
      "name": "35_GroupNorm_",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel_fn at /tmp/pallas_eval/generated/35_GroupNorm__gpt53.py:25 has block shape (Blocked(block_size=1), Blocked(block_size=524288)), array shape (128, 524288), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "sions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel_fn at /tmp/pallas_eval/generated/35_GroupNorm__gpt53.py:25 has block shape (Blocked(block_size=1), Blocked(block_size=524288)), array shape (128, 524288), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 9.5
    },
    {
      "name": "36_RMSNorm_",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Ran out of memory in memory space vmem while allocating on stack for %main.1 = f32[112,64,512,512]{3,2,1,0:T(8,128)} custom-call(%Arg_0.1), custom_call_target=\"tpu_custom_call\", operand_layout_constraints={f32[112,64,512,512]{3,2,1,0}}, metadata={op_name=\"jit(forward)/jit(main)/pallas_call\" source_file=\"/tmp/pallas_eval/generated/36_RMSNorm__gpt53.py\" source_line=26}. Scoped allocation with size 64.00M and limit 32.00M exceeded scoped vmem limit by 32.00M. It should not be possible to run out of scoped vmem - please file a bug against XLA.",
      "traceback": "llocating on stack for %main.1 = f32[112,64,512,512]{3,2,1,0:T(8,128)} custom-call(%Arg_0.1), custom_call_target=\"tpu_custom_call\", operand_layout_constraints={f32[112,64,512,512]{3,2,1,0}}, metadata={op_name=\"jit(forward)/jit(main)/pallas_call\" source_file=\"/tmp/pallas_eval/generated/36_RMSNorm__gpt53.py\" source_line=26}. Scoped allocation with size 64.00M and limit 32.00M exceeded scoped vmem limit by 32.00M. It should not be possible to run out of scoped vmem - please file a bug against XLA.\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 12.0
    },
    {
      "name": "37_FrobeniusNorm_",
      "status": "error",
      "error": "Invalid shape for `swap`. Ref shape: (1, 1). Expected shape: (1, 1). Value shape: (). Transforms: (NDIndexer(indices=(Slice(start=0, size=1, stride=1), Slice(start=0, size=1, stride=1)), shape=(1, 1), int_indexer_shape=(), validate=False),). ",
      "traceback": "orm__gpt53.py\", line 21, in sumsq_kernel\n    o_ref[...] = jnp.sum(val * val)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py\", line 1083, in op\n    return getattr(self.aval, f\"_{name}\")(self, *args)\nValueError: Invalid shape for `swap`. Ref shape: (1, 1). Expected shape: (1, 1). Value shape: (). Transforms: (NDIndexer(indices=(Slice(start=0, size=1, stride=1), Slice(start=0, size=1, stride=1)), shape=(1, 1), int_indexer_shape=(), validate=False),). \n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 10.6
    },
    {
      "name": "38_L1Norm_",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call l1_norm_kernel at /tmp/pallas_eval/generated/38_L1Norm__gpt53.py:12 has block shape (Blocked(block_size=1), Blocked(block_size=8192)), array shape (4096, 8192), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "nsions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call l1_norm_kernel at /tmp/pallas_eval/generated/38_L1Norm__gpt53.py:12 has block shape (Blocked(block_size=1), Blocked(block_size=8192)), array shape (4096, 8192), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 6.8
    },
    {
      "name": "39_L2Norm_",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel_fn at /tmp/pallas_eval/generated/39_L2Norm__gpt53.py:20 has block shape (Blocked(block_size=1), Blocked(block_size=8192)), array shape (4096, 8192), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": " dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel_fn at /tmp/pallas_eval/generated/39_L2Norm__gpt53.py:20 has block shape (Blocked(block_size=1), Blocked(block_size=8192)), array shape (4096, 8192), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 6.5
    },
    {
      "name": "3_Batched_matrix_multiplication",
      "status": "error",
      "error": "Block shape for args[0] (= (Blocked(block_size=128), Blocked(block_size=1024))) must have the same number of dimensions as the array shape (128, 512, 1024).",
      "traceback": "pings = map(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/core.py\", line 999, in _convert_block_spec_to_block_mapping\n    return block_spec.to_block_mapping(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/core.py\", line 512, in to_block_mapping\n    raise ValueError(\nValueError: Block shape for args[0] (= (Blocked(block_size=128), Blocked(block_size=1024))) must have the same number of dimensions as the array shape (128, 512, 1024).\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 12.5
    },
    {
      "name": "40_LayerNorm",
      "status": "error",
      "error": "Incompatible shapes for broadcasting: shapes=[(1, 64, 256, 256), (64, 0, 256)]",
      "traceback": "y\", line 1279, in multiply\n    x, y = promote_args(\"multiply\", x, y)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/util.py\", line 228, in promote_args\n    return promote_shapes(fun_name, *promote_dtypes(*args))\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/util.py\", line 64, in promote_shapes\n    result_rank = len(lax.broadcast_shapes(*shapes))\nValueError: Incompatible shapes for broadcasting: shapes=[(1, 64, 256, 256), (64, 0, 256)]\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 9.3
    },
    {
      "name": "41_Max_Pooling_1D",
      "status": "error",
      "error": "Shape mismatch in input, indices and output",
      "traceback": "xpr_subcomp(lowering_context, jaxpr, *consts, i, *args)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1171, in jaxpr_subcomp\n    ans = lowering_rules[ctx.kernel_type][eqn.primitive](\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 2349, in _gather_lowering_rule\n    raise ValueError(\"Shape mismatch in input, indices and output\")\nValueError: Shape mismatch in input, indices and output\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 8.8
    },
    {
      "name": "42_Max_Pooling_2D",
      "status": "error",
      "error": "Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]\nIt arose in the jnp.arange argument 'start'\nThe error occurred while tracing the function kernel at /tmp/pallas_eval/generated/42_Max_Pooling_2D_gpt53.py:40 for pallas_call kernel. This value became a tracer due to JAX operations on these lines:\n\n  operation a:i32[] = mul b 128:i32[]\n    from line /tmp/pallas_eval/generated/42_Max_Pooling_2D_gpt53.py:44 (kernel)\n\n  operation a:i32[] = mul b 128:i32[]\n    from line /tmp/pallas_eval/generated/42_Max_Pooling_2D_gpt53.py:45 (kernel)\n\n  operation a:i32[] = add b 128:i32[]\n    from line /tmp/pallas_eval/generated/42_Max_Pooling_2D_gpt53.py:47 (kernel)\n\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.ConcretizationTypeError",
      "traceback": "value became a tracer due to JAX operations on these lines:\n\n  operation a:i32[] = mul b 128:i32[]\n    from line /tmp/pallas_eval/generated/42_Max_Pooling_2D_gpt53.py:44 (kernel)\n\n  operation a:i32[] = mul b 128:i32[]\n    from line /tmp/pallas_eval/generated/42_Max_Pooling_2D_gpt53.py:45 (kernel)\n\n  operation a:i32[] = add b 128:i32[]\n    from line /tmp/pallas_eval/generated/42_Max_Pooling_2D_gpt53.py:47 (kernel)\n\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.ConcretizationTypeError\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 9.9
    },
    {
      "name": "43_Max_Pooling_3D",
      "status": "error",
      "error": "Only 2D gather is supported",
      "traceback": "osaic/lowering.py\", line 1046, in body_func\n    return jaxpr_subcomp(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1171, in jaxpr_subcomp\n    ans = lowering_rules[ctx.kernel_type][eqn.primitive](\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 2345, in _gather_lowering_rule\n    raise NotImplementedError(\"Only 2D gather is supported\")\nNotImplementedError: Only 2D gather is supported\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 8.6
    },
    {
      "name": "44_Average_Pooling_1D",
      "status": "error",
      "error": "Unimplemented primitive in Pallas TPU lowering for KernelType.TC: cumsum. Please file an issue on https://github.com/jax-ml/jax/issues.",
      "traceback": "me/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 3199, in _pjit_lowering_rule\n    return jaxpr_subcomp(lowering_context, jaxpr.jaxpr, *args)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1190, in jaxpr_subcomp\n    raise NotImplementedError(\nNotImplementedError: Unimplemented primitive in Pallas TPU lowering for KernelType.TC: cumsum. Please file an issue on https://github.com/jax-ml/jax/issues.\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 9.2
    },
    {
      "name": "45_Average_Pooling_2D",
      "status": "error",
      "error": "Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<int32[]>with<DynamicJaxprTrace>, Traced<int32[]>with<DynamicJaxprTrace>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).",
      "traceback": "ocal/lib/python3.10/site-packages/jax/_src/numpy/indexing.py\", line 929, in index_to_gather\n    raise IndexError(msg)\nIndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<int32[]>with<DynamicJaxprTrace>, Traced<int32[]>with<DynamicJaxprTrace>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 9.5
    },
    {
      "name": "46_Average_Pooling_3D",
      "status": "error",
      "error": "Unimplemented primitive in Pallas TPU lowering for KernelType.TC: reduce_window_sum. Please file an issue on https://github.com/jax-ml/jax/issues.",
      "traceback": "unc_args, **func_kwargs)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1046, in body_func\n    return jaxpr_subcomp(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1190, in jaxpr_subcomp\n    raise NotImplementedError(\nNotImplementedError: Unimplemented primitive in Pallas TPU lowering for KernelType.TC: reduce_window_sum. Please file an issue on https://github.com/jax-ml/jax/issues.\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 8.9
    },
    {
      "name": "47_Sum_reduction_over_a_dimension",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Ran out of memory in memory space vmem while allocating on stack for %main.1 = f32[128,1,4095]{2,1,0:T(1,128)S(1)} custom-call(%copy), custom_call_target=\"tpu_custom_call\", operand_layout_constraints={f32[128,4096,4095]{2,1,0}}, metadata={op_name=\"jit(forward)/jit(main)/pallas_call\" source_file=\"/tmp/pallas_eval/generated/47_Sum_reduction_over_a_dimension_gpt53.py\" source_line=29}. Scoped allocation with size 64.07M and limit 32.00M exceeded scoped vmem limit by 32.07M. It should not be possible to run out of scoped vmem - please file a bug against XLA.",
      "traceback": "tack for %main.1 = f32[128,1,4095]{2,1,0:T(1,128)S(1)} custom-call(%copy), custom_call_target=\"tpu_custom_call\", operand_layout_constraints={f32[128,4096,4095]{2,1,0}}, metadata={op_name=\"jit(forward)/jit(main)/pallas_call\" source_file=\"/tmp/pallas_eval/generated/47_Sum_reduction_over_a_dimension_gpt53.py\" source_line=29}. Scoped allocation with size 64.07M and limit 32.00M exceeded scoped vmem limit by 32.07M. It should not be possible to run out of scoped vmem - please file a bug against XLA.\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 9.8
    },
    {
      "name": "48_Mean_reduction_over_a_dimension",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel at /tmp/pallas_eval/generated/48_Mean_reduction_over_a_dimension_gpt53.py:19 has block shape (Blocked(block_size=128), Blocked(block_size=4096), Blocked(block_size=1)), array shape (128, 4096, 4095), and index_map { lambda ; a:i32[] b:i32[]. let  in (a, 0:i32[], b) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "pectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel at /tmp/pallas_eval/generated/48_Mean_reduction_over_a_dimension_gpt53.py:19 has block shape (Blocked(block_size=128), Blocked(block_size=4096), Blocked(block_size=1)), array shape (128, 4096, 4095), and index_map { lambda ; a:i32[] b:i32[]. let  in (a, 0:i32[], b) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 8.9
    },
    {
      "name": "49_Max_reduction_over_a_dimension",
      "status": "error",
      "error": "Attempted boolean conversion of traced array with shape bool[].\nThe error occurred while tracing the function body at /tmp/pallas_eval/generated/49_Max_reduction_over_a_dimension_gpt53.py:23 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError",
      "traceback": "_slice\n    start, step, size = core.canonicalize_slice(slc, size)\njax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].\nThe error occurred while tracing the function body at /tmp/pallas_eval/generated/49_Max_reduction_over_a_dimension_gpt53.py:23 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 8.7
    },
    {
      "name": "4_Matrix_vector_multiplication_",
      "correct": false,
      "max_diff": 262439.21875,
      "correctness_reason": "values differ",
      "original_ms": 7.8103,
      "generated_ms": 16.3223,
      "speedup": 0.479,
      "original_std_ms": 0.8232,
      "generated_std_ms": 0.2012,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 10.0
    },
    {
      "name": "50_conv_standard_2D__square_input__square_kernel",
      "status": "error",
      "error": "transpose requires ndarray or scalar arguments, got <class 'NoneType'> at position 0.",
      "traceback": "py\", line 1197, in transpose\n    a = util.ensure_arraylike(\"transpose\", a)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/util.py\", line 155, in ensure_arraylike\n    check_arraylike(fun_name, *args)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/util.py\", line 181, in check_arraylike\n    raise TypeError(msg.format(fun_name, type(arg), pos))\nTypeError: transpose requires ndarray or scalar arguments, got <class 'NoneType'> at position 0.\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 7.8
    },
    {
      "name": "51_Argmax_over_a_dimension",
      "status": "error",
      "error": "Unimplemented primitive in Pallas TPU lowering for KernelType.TC: argmax. Please file an issue on https://github.com/jax-ml/jax/issues.",
      "traceback": "lues = f(*func_args, **func_kwargs)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1046, in body_func\n    return jaxpr_subcomp(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1190, in jaxpr_subcomp\n    raise NotImplementedError(\nNotImplementedError: Unimplemented primitive in Pallas TPU lowering for KernelType.TC: argmax. Please file an issue on https://github.com/jax-ml/jax/issues.\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 9.0
    },
    {
      "name": "52_Argmin_over_a_dimension",
      "status": "error",
      "error": "",
      "traceback": "\nThe above exception was the direct cause of the following exception:\n\nTraceback (most recent call last):\n  File \"/tmp/pallas_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 95, in eval_jaxkernelbench\n    test_out = jax.jit(gen_model.forward)(*inputs)\n  File \"/tmp/pallas_eval/generated/52_Argmin_over_a_dimension_gpt53.py\", line 34, in forward\n    assert D2 % b2 == 0\nAssertionError\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 8.9
    },
    {
      "name": "53_Min_reduction_over_a_dimension",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call min_reduce_kernel at /tmp/pallas_eval/generated/53_Min_reduction_over_a_dimension_gpt53.py:7 has block shape (Blocked(block_size=1), Blocked(block_size=4096), Blocked(block_size=45)), array shape (128, 4096, 4095), and index_map { lambda ; a:i32[] b:i32[]. let  in (a, 0:i32[], b) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "y, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call min_reduce_kernel at /tmp/pallas_eval/generated/53_Min_reduction_over_a_dimension_gpt53.py:7 has block shape (Blocked(block_size=1), Blocked(block_size=4096), Blocked(block_size=45)), array shape (128, 4096, 4095), and index_map { lambda ; a:i32[] b:i32[]. let  in (a, 0:i32[], b) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 8.8
    },
    {
      "name": "54_conv_standard_3D__square_input__square_kernel",
      "status": "error",
      "error": "Unimplemented primitive in Pallas TPU lowering for KernelType.TC: conv_general_dilated. Please file an issue on https://github.com/jax-ml/jax/issues.",
      "traceback": "_args, **func_kwargs)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1046, in body_func\n    return jaxpr_subcomp(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1190, in jaxpr_subcomp\n    raise NotImplementedError(\nNotImplementedError: Unimplemented primitive in Pallas TPU lowering for KernelType.TC: conv_general_dilated. Please file an issue on https://github.com/jax-ml/jax/issues.\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 11.1
    },
    {
      "name": "55_conv_standard_2D__asymmetric_input__square_kernel",
      "status": "error",
      "error": "'NoneType' object has no attribute 'shape'",
      "traceback": "cent call last):\n  File \"/tmp/pallas_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 91, in eval_jaxkernelbench\n    ref_out = jax.jit(orig_model.forward)(*inputs)\n  File \"/tmp/pallas_eval/originals/55_conv_standard_2D__asymmetric_input__square_kernel_original.py\", line 41, in forward\n    out = jax.lax.conv_general_dilated(\nAttributeError: 'NoneType' object has no attribute 'shape'\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 9.3
    },
    {
      "name": "56_conv_standard_2D__asymmetric_input__asymmetric_kernel",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 1.8668,
      "generated_ms": 5.4921,
      "speedup": 0.34,
      "original_std_ms": 0.0044,
      "generated_std_ms": 0.008,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 90.2
    },
    {
      "name": "57_conv_transposed_2D__square_input__square_kernel",
      "status": "error",
      "error": "Unimplemented primitive in Pallas TPU lowering for KernelType.TC: dynamic_slice. Please file an issue on https://github.com/jax-ml/jax/issues.",
      "traceback": "aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 2983, in _run_body\n    args = jaxpr_subcomp(lowering_context, jaxpr, *consts, i, *args)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1190, in jaxpr_subcomp\n    raise NotImplementedError(\nNotImplementedError: Unimplemented primitive in Pallas TPU lowering for KernelType.TC: dynamic_slice. Please file an issue on https://github.com/jax-ml/jax/issues.\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 17.5
    },
    {
      "name": "58_conv_transposed_3D__asymmetric_input__asymmetric_kernel",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 3.4916,
      "generated_ms": 8.0219,
      "speedup": 0.435,
      "original_std_ms": 0.0066,
      "generated_std_ms": 0.012,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 264.4
    },
    {
      "name": "59_conv_standard_3D__asymmetric_input__square_kernel",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for outputs in pallas_call kernel at /tmp/pallas_eval/generated/59_conv_standard_3D__asymmetric_input__square_kernel_gpt53.py:52 has block shape (Blocked(block_size=1), Blocked(block_size=64)), array shape (10322560, 64), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "isible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for outputs in pallas_call kernel at /tmp/pallas_eval/generated/59_conv_standard_3D__asymmetric_input__square_kernel_gpt53.py:52 has block shape (Blocked(block_size=1), Blocked(block_size=64)), array shape (10322560, 64), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 10.3
    },
    {
      "name": "5_Matrix_scalar_multiplication",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 7.2994,
      "generated_ms": 31.0112,
      "speedup": 0.235,
      "original_std_ms": 0.4469,
      "generated_std_ms": 0.0179,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 19.5
    },
    {
      "name": "60_conv_standard_3D__square_input__asymmetric_kernel",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for outputs in pallas_call kernel_fn at /tmp/pallas_eval/generated/60_conv_standard_3D__square_input__asymmetric_kernel_gpt53.py:54 has block shape (Blocked(block_size=1), Blocked(block_size=64)), array shape (3452160, 64), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "ible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for outputs in pallas_call kernel_fn at /tmp/pallas_eval/generated/60_conv_standard_3D__square_input__asymmetric_kernel_gpt53.py:54 has block shape (Blocked(block_size=1), Blocked(block_size=64)), array shape (3452160, 64), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 24.6
    },
    {
      "name": "61_conv_transposed_3D__square_input__square_kernel",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 5.6035,
      "generated_ms": 11.3411,
      "speedup": 0.494,
      "original_std_ms": 0.0077,
      "generated_std_ms": 1.5866,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 133.5
    },
    {
      "name": "62_conv_standard_2D__square_input__asymmetric_kernel",
      "status": "error",
      "error": "Invalid shape for `swap`. Ref shape: (1, 1, 1, 1). Expected shape: (). Value shape: (1,). Transforms: (NDIndexer(indices=(ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(int32[])), shape=(1, 1, 1, 1), int_indexer_shape=(), validate=False),). ",
      "traceback": "  o_ref[n, h, w, oc] = acc.astype(o_ref.dtype)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py\", line 1083, in op\n    return getattr(self.aval, f\"_{name}\")(self, *args)\nValueError: Invalid shape for `swap`. Ref shape: (1, 1, 1, 1). Expected shape: (). Value shape: (1,). Transforms: (NDIndexer(indices=(ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(int32[])), shape=(1, 1, 1, 1), int_indexer_shape=(), validate=False),). \n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 50.5
    },
    {
      "name": "63_conv_standard_2D__square_input__square_kernel",
      "status": "error",
      "error": "Cannot broadcast shapes for indexing: ((8,), (8,), (128,), ())",
      "traceback": "standard_2D__square_input__square_kernel_gpt53.py\", line 38, in conv_kernel\n    x_val = x_ref[r, h_in, w_in, c]\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py\", line 1083, in op\n    return getattr(self.aval, f\"_{name}\")(self, *args)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/state/indexing.py\", line 255, in from_indices_shape\n    raise ValueError(\nValueError: Cannot broadcast shapes for indexing: ((8,), (8,), (128,), ())\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 22.5
    },
    {
      "name": "64_conv_transposed_1D",
      "status": "error",
      "error": "Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<int32[]>with<DynamicJaxprTrace>, Traced<int32[]>with<DynamicJaxprTrace>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).",
      "traceback": "ocal/lib/python3.10/site-packages/jax/_src/numpy/indexing.py\", line 929, in index_to_gather\n    raise IndexError(msg)\nIndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<int32[]>with<DynamicJaxprTrace>, Traced<int32[]>with<DynamicJaxprTrace>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 11.7
    },
    {
      "name": "65_conv_transposed_2D__square_input__asymmetric_kernel",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Allocation (size=681574400) would exceed memory (size=134217728) :: #allocation2 [shape = 'u8[681574400]{0}', space=vmem, size = 0x28a00000, tag = 'operand span for operand 0'] :: main.1",
      "traceback": ":\n  File \"/tmp/pallas_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 95, in eval_jaxkernelbench\n    test_out = jax.jit(gen_model.forward)(*inputs)\njaxlib._jax.XlaRuntimeError: RESOURCE_EXHAUSTED: Allocation (size=681574400) would exceed memory (size=134217728) :: #allocation2 [shape = 'u8[681574400]{0}', space=vmem, size = 0x28a00000, tag = 'operand span for operand 0'] :: main.1\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 29.1
    },
    {
      "name": "66_conv_standard_3D__asymmetric_input__asymmetric_kernel",
      "status": "error",
      "error": "Incompatible shapes for broadcasting: shapes=[(128, 1, 3, 1), (1, 128, 3)]",
      "traceback": "cs.py\", line 1279, in multiply\n    x, y = promote_args(\"multiply\", x, y)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/util.py\", line 228, in promote_args\n    return promote_shapes(fun_name, *promote_dtypes(*args))\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/util.py\", line 64, in promote_shapes\n    result_rank = len(lax.broadcast_shapes(*shapes))\nValueError: Incompatible shapes for broadcasting: shapes=[(128, 1, 3, 1), (1, 128, 3)]\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 52.8
    },
    {
      "name": "67_conv_standard_1D",
      "status": "error",
      "error": "Pytree for `in_specs` and `inputs` do not match. There are 1 mismatches, including:\n    * `in_specs` is a tuple of length 2 but `inputs` is a tuple of length 9, so the lengths do not match",
      "traceback": "hand/.local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py\", line 1687, in wrapped\n    kernel_args, grid_mapping = pallas_core.get_grid_mapping(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/core.py\", line 1137, in get_grid_mapping\n    raise ValueError(\nValueError: Pytree for `in_specs` and `inputs` do not match. There are 1 mismatches, including:\n    * `in_specs` is a tuple of length 2 but `inputs` is a tuple of length 9, so the lengths do not match\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 9.8
    },
    {
      "name": "68_conv_transposed_3D__square_input__asymmetric_kernel",
      "status": "error",
      "error": "Unimplemented primitive in Pallas TPU lowering for KernelType.TC: scatter-add. Please file an issue on https://github.com/jax-ml/jax/issues.",
      "traceback": "= f(*func_args, **func_kwargs)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1046, in body_func\n    return jaxpr_subcomp(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1190, in jaxpr_subcomp\n    raise NotImplementedError(\nNotImplementedError: Unimplemented primitive in Pallas TPU lowering for KernelType.TC: scatter-add. Please file an issue on https://github.com/jax-ml/jax/issues.\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 111.2
    },
    {
      "name": "69_conv_transposed_2D__asymmetric_input__asymmetric_kernel",
      "status": "error",
      "error": "Incompatible shapes for broadcasting: shapes=[(128,), (64,), ()]",
      "traceback": "rc/numpy/lax_numpy.py\", line 2821, in where\n    return util._where(condition, x, y)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/util.py\", line 311, in _where\n    condition, x_arr, y_arr = _broadcast_arrays(condition, x, y)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/util.py\", line 264, in _broadcast_arrays\n    result_shape = lax.broadcast_shapes(*shapes)\nValueError: Incompatible shapes for broadcasting: shapes=[(128,), (64,), ()]\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 14.1
    },
    {
      "name": "6_Matmul_with_large_K_dimension_",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Allocation (size=268435456) would exceed memory (size=134217728) :: #allocation2 [shape = 'u8[268435456]{0}', space=vmem, size = 0x10000000, tag = 'operand span for operand 0'] :: main.1",
      "traceback": ":\n  File \"/tmp/pallas_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 95, in eval_jaxkernelbench\n    test_out = jax.jit(gen_model.forward)(*inputs)\njaxlib._jax.XlaRuntimeError: RESOURCE_EXHAUSTED: Allocation (size=268435456) would exceed memory (size=134217728) :: #allocation2 [shape = 'u8[268435456]{0}', space=vmem, size = 0x10000000, tag = 'operand span for operand 0'] :: main.1\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 15.8
    },
    {
      "name": "70_conv_transposed_3D__asymmetric_input__square_kernel",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "status": "error",
      "error": "ERROR: Command timed out"
    },
    {
      "name": "71_conv_transposed_2D__asymmetric_input__square_kernel",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 6.5041,
      "generated_ms": 12.4984,
      "speedup": 0.52,
      "original_std_ms": 0.1569,
      "generated_std_ms": 0.1853,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 56.1
    },
    {
      "name": "72_conv_transposed_3D_asymmetric_input_asymmetric_kernel___strided_padded_grouped_",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 5.2737,
      "generated_ms": 19.5388,
      "speedup": 0.27,
      "original_std_ms": 0.0095,
      "generated_std_ms": 0.0147,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 173.6
    },
    {
      "name": "73_conv_transposed_3D_asymmetric_input_square_kernel__strided_padded__grouped",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 14.1088,
      "generated_ms": 35.6172,
      "speedup": 0.396,
      "original_std_ms": 0.0744,
      "generated_std_ms": 0.1176,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 108.8
    },
    {
      "name": "74_conv_transposed_1D_dilated",
      "status": "error",
      "error": "Attempted boolean conversion of traced array with shape bool[].\nThe error occurred while tracing the function body_c at /tmp/pallas_eval/generated/74_conv_transposed_1D_dilated_gpt53.py:46 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError",
      "traceback": "om_slice\n    start, step, size = core.canonicalize_slice(slc, size)\njax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].\nThe error occurred while tracing the function body_c at /tmp/pallas_eval/generated/74_conv_transposed_1D_dilated_gpt53.py:46 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 9.7
    },
    {
      "name": "75_conv_transposed_2D_asymmetric_input_asymmetric_kernel_strided__grouped____padded____dilated__",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 14.6406,
      "generated_ms": 130.1539,
      "speedup": 0.112,
      "original_std_ms": 0.2036,
      "generated_std_ms": 0.0433,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 29.1
    },
    {
      "name": "76_conv_standard_1D_dilated_strided__",
      "status": "error",
      "error": "dot_general requires contracting dimensions to have the same shape, got (64,) and (128,).",
      "traceback": ", in kernel\n    acc = jax.lax.fori_loop(0, K, body, acc)\n  File \"/tmp/pallas_eval/generated/76_conv_standard_1D_dilated_strided___gpt53.py\", line 74, in body\n    acc = acc + jnp.dot(x_vals, w_vals)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/tensor_contractions.py\", line 121, in dot\n    result = lax.dot_general(a, b, dimension_numbers=(contract_dims, batch_dims),\nTypeError: dot_general requires contracting dimensions to have the same shape, got (64,) and (128,).\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 11.8
    },
    {
      "name": "77_conv_transposed_3D_square_input_square_kernel___padded____dilated____strided__",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for outputs in pallas_call kernel_fn at /tmp/pallas_eval/generated/77_conv_transposed_3D_square_input_square_kernel___padded____dilated____strided___gpt53.py:51 has block shape (Blocked(block_size=1), Blocked(block_size=64)), array shape (2230800, 64), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "y, or be equal to the respective dimensions of the overall array. Block spec for outputs in pallas_call kernel_fn at /tmp/pallas_eval/generated/77_conv_transposed_3D_square_input_square_kernel___padded____dilated____strided___gpt53.py:51 has block shape (Blocked(block_size=1), Blocked(block_size=64)), array shape (2230800, 64), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 25.0
    },
    {
      "name": "78_conv_transposed_2D_asymmetric_input_asymmetric_kernel___padded__",
      "status": "error",
      "error": "conv_general_dilated lhs feature dimension size divided by feature_group_count must equal the rhs input feature dimension size, but 32 // 1 != 3.",
      "traceback": "lbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 91, in eval_jaxkernelbench\n    ref_out = jax.jit(orig_model.forward)(*inputs)\n  File \"/tmp/pallas_eval/originals/78_conv_transposed_2D_asymmetric_input_asymmetric_kernel___padded___original.py\", line 44, in forward\n    out = lax.conv_transpose(\nValueError: conv_general_dilated lhs feature dimension size divided by feature_group_count must equal the rhs input feature dimension size, but 32 // 1 != 3.\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 12.5
    },
    {
      "name": "79_conv_transposed_1D_asymmetric_input_square_kernel___padded____strided____dilated__",
      "status": "error",
      "error": "Index map function <lambda> at /tmp/pallas_eval/generated/79_conv_transposed_1D_asymmetric_input_square_kernel___padded____strided____dilated___gpt53.py:94 for args[1] must return 3 values to match block_shape=(Blocked(block_size=5), Blocked(block_size=64), Blocked(block_size=32)). Currently returning 2 values:",
      "traceback": "eturn block_spec.to_block_mapping(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/core.py\", line 564, in to_block_mapping\n    raise ValueError(\nValueError: Index map function <lambda> at /tmp/pallas_eval/generated/79_conv_transposed_1D_asymmetric_input_square_kernel___padded____strided____dilated___gpt53.py:94 for args[1] must return 3 values to match block_shape=(Blocked(block_size=5), Blocked(block_size=64), Blocked(block_size=32)). Currently returning 2 values:\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 10.5
    },
    {
      "name": "7_Matmul_with_small_K_dimension_",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 3.294,
      "generated_ms": 21.8383,
      "speedup": 0.151,
      "original_std_ms": 0.4647,
      "generated_std_ms": 0.0431,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 14.3
    },
    {
      "name": "80_conv_standard_2D_square_input_asymmetric_kernel___dilated____padded__",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "status": "error",
      "error": "coredump_hook.cc:405] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory\nE0401 08:44:51.597491  452220 coredump_hook.cc:457] RAW: Dumping core locally.\nF0401 08:44:51.587944  452220 array.h:411] Check failed: indexes.size() == num_dimensions() (0 vs. 3) \nE0401 08:44:51.734812  452220 process_state.cc:808] RAW: Raising signal 6 with default behavior"
    },
    {
      "name": "81_conv_transposed_2D_asymmetric_input_square_kernel___dilated____padded____strided__",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 3.4282,
      "generated_ms": 12.121,
      "speedup": 0.283,
      "original_std_ms": 0.2534,
      "generated_std_ms": 0.0392,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 21.4
    },
    {
      "name": "82_conv_depthwise_2D_square_input_square_kernel",
      "status": "error",
      "error": "Attempted boolean conversion of traced array with shape bool[].\nThe error occurred while tracing the function kernel at /tmp/pallas_eval/generated/82_conv_depthwise_2D_square_input_square_kernel_gpt53.py:46 for pallas_call kernel. This value became a tracer due to JAX operations on these lines:\n\n  operation a:i32[] = mul b 128:i32[]\n    from line /tmp/pallas_eval/generated/82_conv_depthwise_2D_square_input_square_kernel_gpt53.py:51 (kernel)\n\n  operation a:i32[] = mul b 128:i32[]\n    from line /tmp/pallas_eval/generated/82_conv_depthwise_2D_square_input_square_kernel_gpt53.py:52 (kernel)\n\n  operation a:i32[] = add b 128:i32[]\n    from line /tmp/pallas_eval/generated/82_conv_depthwise_2D_square_input_square_kernel_gpt53.py:54 (kernel)\n\n  operation a:i32[] = add b 3:i32[]\n    from line /tmp/pallas_eval/generated/82_conv_depthwise_2D_square_input_square_kernel_gpt53.py:54 (kernel)\n\n  operation a:i32[] = sub b 1:i32[]\n    from line /tmp/pallas_eval/generated/82_conv_depthwise_2D_square_input_square_kernel_gpt53.py:54 (kernel)\n\n(Additional originating lines are not shown.)\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError",
      "traceback": "rated/82_conv_depthwise_2D_square_input_square_kernel_gpt53.py:54 (kernel)\n\n  operation a:i32[] = add b 3:i32[]\n    from line /tmp/pallas_eval/generated/82_conv_depthwise_2D_square_input_square_kernel_gpt53.py:54 (kernel)\n\n  operation a:i32[] = sub b 1:i32[]\n    from line /tmp/pallas_eval/generated/82_conv_depthwise_2D_square_input_square_kernel_gpt53.py:54 (kernel)\n\n(Additional originating lines are not shown.)\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 14.2
    },
    {
      "name": "83_conv_depthwise_2D_square_input_asymmetric_kernel",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel at /tmp/pallas_eval/generated/83_conv_depthwise_2D_square_input_asymmetric_kernel_gpt53.py:52 has block shape (Blocked(block_size=1), Blocked(block_size=1), Blocked(block_size=1), Blocked(block_size=8)), array shape (64, 510, 512, 8), and index_map { lambda ; a:i32[] b:i32[] c:i32[] d:i32[]. let  in (a, b, c, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": " the overall array. Block spec for args[0] in pallas_call kernel at /tmp/pallas_eval/generated/83_conv_depthwise_2D_square_input_asymmetric_kernel_gpt53.py:52 has block shape (Blocked(block_size=1), Blocked(block_size=1), Blocked(block_size=1), Blocked(block_size=8)), array shape (64, 510, 512, 8), and index_map { lambda ; a:i32[] b:i32[] c:i32[] d:i32[]. let  in (a, b, c, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 12.6
    },
    {
      "name": "84_conv_depthwise_2D_asymmetric_input_square_kernel",
      "status": "error",
      "error": "`broadcast_to` is a Triton-specific primitive. Please consider using `jnp.broadcast_to` instead.",
      "traceback": "ing_context, jaxpr, *consts, i, *args)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1171, in jaxpr_subcomp\n    ans = lowering_rules[ctx.kernel_type][eqn.primitive](\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1913, in _broadcast_to_lowering_rule\n    raise RuntimeError(\nRuntimeError: `broadcast_to` is a Triton-specific primitive. Please consider using `jnp.broadcast_to` instead.\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 15.7
    },
    {
      "name": "85_conv_depthwise_2D_asymmetric_input_asymmetric_kernel",
      "status": "error",
      "error": "Index map function <lambda> at /tmp/pallas_eval/generated/85_conv_depthwise_2D_asymmetric_input_asymmetric_kernel_gpt53.py:87 for args[0] must return 4 values to match block_shape=(Blocked(block_size=32), Blocked(block_size=128), Blocked(block_size=256), Blocked(block_size=128)). Currently returning 1 values:",
      "traceback": " return block_spec.to_block_mapping(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/core.py\", line 564, in to_block_mapping\n    raise ValueError(\nValueError: Index map function <lambda> at /tmp/pallas_eval/generated/85_conv_depthwise_2D_asymmetric_input_asymmetric_kernel_gpt53.py:87 for args[0] must return 4 values to match block_shape=(Blocked(block_size=32), Blocked(block_size=128), Blocked(block_size=256), Blocked(block_size=128)). Currently returning 1 values:\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 14.8
    },
    {
      "name": "86_conv_depthwise_separable_2D",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "status": "error",
      "error": "ERROR: Command timed out"
    },
    {
      "name": "87_conv_pointwise_2D",
      "status": "error",
      "error": "Attempted boolean conversion of traced array with shape bool[].\nThe error occurred while tracing the function body at /tmp/pallas_eval/generated/87_conv_pointwise_2D_gpt53.py:40 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError",
      "traceback": "e 76, in from_slice\n    start, step, size = core.canonicalize_slice(slc, size)\njax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].\nThe error occurred while tracing the function body at /tmp/pallas_eval/generated/87_conv_pointwise_2D_gpt53.py:40 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 9.9
    },
    {
      "name": "88_MinGPTNewGelu",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 0.4901,
      "generated_ms": 1.9933,
      "speedup": 0.246,
      "original_std_ms": 0.0044,
      "generated_std_ms": 0.0063,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 8.1
    },
    {
      "name": "89_cumsum",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call cumsum_kernel at /tmp/pallas_eval/generated/89_cumsum_gpt53.py:12 has block shape (Blocked(block_size=1), Blocked(block_size=32768)), array shape (32768, 32768), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "sions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call cumsum_kernel at /tmp/pallas_eval/generated/89_cumsum_gpt53.py:12 has block shape (Blocked(block_size=1), Blocked(block_size=32768)), array shape (32768, 32768), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 12.0
    },
    {
      "name": "8_Matmul_with_irregular_shapes_",
      "status": "error",
      "error": "Attempted boolean conversion of traced array with shape bool[].\nThe error occurred while tracing the function body at /tmp/pallas_eval/generated/8_Matmul_with_irregular_shapes__gpt53.py:14 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError",
      "traceback": "om_slice\n    start, step, size = core.canonicalize_slice(slc, size)\njax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].\nThe error occurred while tracing the function body at /tmp/pallas_eval/generated/8_Matmul_with_irregular_shapes__gpt53.py:14 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 10.0
    },
    {
      "name": "90_cumprod",
      "status": "error",
      "error": "Unimplemented primitive in Pallas TPU lowering for KernelType.TC: cumprod. Please file an issue on https://github.com/jax-ml/jax/issues.",
      "traceback": "e/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 3199, in _pjit_lowering_rule\n    return jaxpr_subcomp(lowering_context, jaxpr.jaxpr, *args)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1190, in jaxpr_subcomp\n    raise NotImplementedError(\nNotImplementedError: Unimplemented primitive in Pallas TPU lowering for KernelType.TC: cumprod. Please file an issue on https://github.com/jax-ml/jax/issues.\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 11.9
    },
    {
      "name": "91_cumsum_reverse",
      "status": "error",
      "error": "Unimplemented primitive in Pallas TPU lowering for KernelType.TC: rev. Please file an issue on https://github.com/jax-ml/jax/issues.",
      "traceback": "/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 3199, in _pjit_lowering_rule\n    return jaxpr_subcomp(lowering_context, jaxpr.jaxpr, *args)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1190, in jaxpr_subcomp\n    raise NotImplementedError(\nNotImplementedError: Unimplemented primitive in Pallas TPU lowering for KernelType.TC: rev. Please file an issue on https://github.com/jax-ml/jax/issues.\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 12.1
    },
    {
      "name": "92_cumsum_exclusive",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel_fn at /tmp/pallas_eval/generated/92_cumsum_exclusive_gpt53.py:14 has block shape (Blocked(block_size=1), Blocked(block_size=32768)), array shape (32768, 32768), and index_map { lambda ; a:i32[] b:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel_fn at /tmp/pallas_eval/generated/92_cumsum_exclusive_gpt53.py:14 has block shape (Blocked(block_size=1), Blocked(block_size=32768)), array shape (32768, 32768), and index_map { lambda ; a:i32[] b:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 12.0
    },
    {
      "name": "93_masked_cumsum",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel_fn at /tmp/pallas_eval/generated/93_masked_cumsum_gpt53.py:21 has block shape (Blocked(block_size=1), Blocked(block_size=32768)), array shape (32768, 32768), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "ns of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel_fn at /tmp/pallas_eval/generated/93_masked_cumsum_gpt53.py:21 has block shape (Blocked(block_size=1), Blocked(block_size=32768)), array shape (32768, 32768), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 14.6
    },
    {
      "name": "94_MSELoss",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 7.8336,
      "generated_ms": 34.1846,
      "speedup": 0.229,
      "original_std_ms": 0.8427,
      "generated_std_ms": 0.0217,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 10.5
    },
    {
      "name": "95_CrossEntropyLoss",
      "status": "error",
      "error": "",
      "traceback": "lowering.py\", line 3199, in _pjit_lowering_rule\n    return jaxpr_subcomp(lowering_context, jaxpr.jaxpr, *args)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1171, in jaxpr_subcomp\n    ans = lowering_rules[ctx.kernel_type][eqn.primitive](\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 2357, in _gather_lowering_rule\n    assert indices_aval.shape == in_aval.shape + (1,)\nAssertionError\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 8.7
    },
    {
      "name": "96_HuberLoss",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 0.2039,
      "generated_ms": 0.5747,
      "speedup": 0.355,
      "original_std_ms": 0.0043,
      "generated_std_ms": 0.0043,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 7.7
    },
    {
      "name": "97_ScaledDotProductAttention",
      "status": "error",
      "error": "",
      "traceback": "and/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1046, in body_func\n    return jaxpr_subcomp(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1171, in jaxpr_subcomp\n    ans = lowering_rules[ctx.kernel_type][eqn.primitive](\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 2398, in _transpose_lowering_rule\n    raise NotImplementedError\nNotImplementedError\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 10.9
    },
    {
      "name": "98_KLDivLoss",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for outputs in pallas_call kldiv_kernel at /tmp/pallas_eval/generated/98_KLDivLoss_gpt53.py:8 has block shape (Blocked(block_size=1), Blocked(block_size=1)), array shape (128, 128), and index_map { lambda ; a:i32[] b:i32[]. let  in (a, b) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for outputs in pallas_call kldiv_kernel at /tmp/pallas_eval/generated/98_KLDivLoss_gpt53.py:8 has block shape (Blocked(block_size=1), Blocked(block_size=1)), array shape (128, 128), and index_map { lambda ; a:i32[] b:i32[]. let  in (a, b) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 9.5
    },
    {
      "name": "99_TripletMarginLoss",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 2.9204,
      "generated_ms": 2.9278,
      "speedup": 0.997,
      "original_std_ms": 0.8665,
      "generated_std_ms": 0.9304,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 10.3
    },
    {
      "name": "9_Tall_skinny_matrix_multiplication_",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 3.298,
      "generated_ms": 21.5241,
      "speedup": 0.153,
      "original_std_ms": 0.2272,
      "generated_std_ms": 0.2162,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level1",
      "eval_time_s": 19.3
    },
    {
      "name": "100_ConvTranspose3d_Clamp_Min_Divide",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 5.1101,
      "generated_ms": 34.8599,
      "speedup": 0.147,
      "original_std_ms": 0.5286,
      "generated_std_ms": 0.2262,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 50.9
    },
    {
      "name": "10_ConvTranspose2d_MaxPool_Hardtanh_Mean_Tanh",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 9.157,
      "generated_ms": 11.4826,
      "speedup": 0.797,
      "original_std_ms": 0.0111,
      "generated_std_ms": 0.0128,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 29.1
    },
    {
      "name": "11_ConvTranspose2d_BatchNorm_Tanh_MaxPool_GroupNorm",
      "correct": true,
      "max_diff": 1e-06,
      "correctness_reason": "ok",
      "original_ms": 1.2126,
      "generated_ms": 4.7993,
      "speedup": 0.253,
      "original_std_ms": 0.0729,
      "generated_std_ms": 0.0069,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 31.0
    },
    {
      "name": "12_Gemm_Multiply_LeakyReLU",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Allocation (size=268435456) would exceed memory (size=134217728) :: #allocation2 [shape = 'u8[268435456]{0}', space=vmem, size = 0x10000000, tag = 'operand span for operand 1'] :: main.1",
      "traceback": ":\n  File \"/tmp/pallas_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 95, in eval_jaxkernelbench\n    test_out = jax.jit(gen_model.forward)(*inputs)\njaxlib._jax.XlaRuntimeError: RESOURCE_EXHAUSTED: Allocation (size=268435456) would exceed memory (size=134217728) :: #allocation2 [shape = 'u8[268435456]{0}', space=vmem, size = 0x10000000, tag = 'operand span for operand 1'] :: main.1\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 8.9
    },
    {
      "name": "13_ConvTranspose3d_Mean_Add_Softmax_Tanh_Scaling",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 13.1104,
      "generated_ms": 13.251,
      "speedup": 0.989,
      "original_std_ms": 0.0351,
      "generated_std_ms": 0.0678,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 152.5
    },
    {
      "name": "14_Gemm_Divide_Sum_Scaling",
      "correct": true,
      "max_diff": 0.005127,
      "correctness_reason": "ok",
      "original_ms": 1.1206,
      "generated_ms": 0.3221,
      "speedup": 3.479,
      "original_std_ms": 0.0077,
      "generated_std_ms": 0.004,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 62.9
    },
    {
      "name": "15_ConvTranspose3d_BatchNorm_Subtract",
      "correct": false,
      "max_diff": 27.545929,
      "correctness_reason": "values differ",
      "original_ms": 8.7306,
      "generated_ms": 10.1085,
      "speedup": 0.864,
      "original_std_ms": 0.1083,
      "generated_std_ms": 0.0477,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 118.8
    },
    {
      "name": "16_ConvTranspose2d_Mish_Add_Hardtanh_Scaling",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 8.1705,
      "generated_ms": 14.101,
      "speedup": 0.579,
      "original_std_ms": 0.0295,
      "generated_std_ms": 0.0297,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 22.7
    },
    {
      "name": "17_Conv2d_InstanceNorm_Divide",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 4.6367,
      "generated_ms": 25.865,
      "speedup": 0.179,
      "original_std_ms": 0.1691,
      "generated_std_ms": 0.2034,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 26.4
    },
    {
      "name": "18_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 0.1456,
      "generated_ms": 0.1485,
      "speedup": 0.98,
      "original_std_ms": 0.0041,
      "generated_std_ms": 0.0059,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 8.1
    },
    {
      "name": "19_ConvTranspose2d_GELU_GroupNorm",
      "correct": false,
      "max_diff": 62.15794,
      "correctness_reason": "values differ",
      "original_ms": 12.3963,
      "generated_ms": 48.9015,
      "speedup": 0.253,
      "original_std_ms": 0.1127,
      "generated_std_ms": 0.0764,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 38.4
    },
    {
      "name": "1_Conv2D_ReLU_BiasAdd",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 2.1839,
      "generated_ms": 14.8656,
      "speedup": 0.147,
      "original_std_ms": 0.007,
      "generated_std_ms": 0.0173,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 24.0
    },
    {
      "name": "20_ConvTranspose3d_Sum_ResidualAdd_Multiply_ResidualAdd",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 4.4653,
      "generated_ms": 9.662,
      "speedup": 0.462,
      "original_std_ms": 0.4374,
      "generated_std_ms": 0.0144,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 38.0
    },
    {
      "name": "21_Conv2d_Add_Scale_Sigmoid_GroupNorm",
      "status": "error",
      "error": "Block shape for args[0] (= (Blocked(block_size=258064), Blocked(block_size=1))) must have the same number of dimensions as the array shape (1024, 258064, 1).",
      "traceback": "ings = map(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/core.py\", line 999, in _convert_block_spec_to_block_mapping\n    return block_spec.to_block_mapping(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/core.py\", line 512, in to_block_mapping\n    raise ValueError(\nValueError: Block shape for args[0] (= (Blocked(block_size=258064), Blocked(block_size=1))) must have the same number of dimensions as the array shape (1024, 258064, 1).\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 10.4
    },
    {
      "name": "22_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish",
      "status": "error",
      "error": "'Model' object has no attribute 'matmul_weight'",
      "traceback": "\"/tmp/pallas_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 91, in eval_jaxkernelbench\n    ref_out = jax.jit(orig_model.forward)(*inputs)\n  File \"/tmp/pallas_eval/originals/22_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish_original.py\", line 24, in forward\n    x = jnp.matmul(x, self.matmul_weight.T) + self.matmul_bias\nAttributeError: 'Model' object has no attribute 'matmul_weight'\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 6.4
    },
    {
      "name": "23_Conv3d_GroupNorm_Mean",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call mean_kernel at /tmp/pallas_eval/generated/23_Conv3d_GroupNorm_Mean_gpt53.py:46 has block shape (Blocked(block_size=1), Blocked(block_size=475200)), array shape (128, 475200), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": " block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call mean_kernel at /tmp/pallas_eval/generated/23_Conv3d_GroupNorm_Mean_gpt53.py:46 has block shape (Blocked(block_size=1), Blocked(block_size=475200)), array shape (128, 475200), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 10.0
    },
    {
      "name": "24_Conv3d_Min_Softmax",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call post_kernel at /tmp/pallas_eval/generated/24_Conv3d_Min_Softmax_gpt53.py:7 has block shape (Blocked(block_size=1), Blocked(block_size=22), Blocked(block_size=1), Blocked(block_size=1), Blocked(block_size=24)), array shape (128, 22, 30, 30, 24), and index_map { lambda ; a:i32[] b:i32[] c:i32[]. let  in (a, 0:i32[], b, c, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": " overall array. Block spec for args[0] in pallas_call post_kernel at /tmp/pallas_eval/generated/24_Conv3d_Min_Softmax_gpt53.py:7 has block shape (Blocked(block_size=1), Blocked(block_size=22), Blocked(block_size=1), Blocked(block_size=1), Blocked(block_size=24)), array shape (128, 22, 30, 30, 24), and index_map { lambda ; a:i32[] b:i32[] c:i32[]. let  in (a, 0:i32[], b, c, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 10.2
    },
    {
      "name": "25_Conv2d_Min_Tanh_Tanh",
      "status": "error",
      "error": "Unimplemented primitive in Pallas TPU lowering for KernelType.TC: dynamic_slice. Please file an issue on https://github.com/jax-ml/jax/issues.",
      "traceback": "aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 2983, in _run_body\n    args = jaxpr_subcomp(lowering_context, jaxpr, *consts, i, *args)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1190, in jaxpr_subcomp\n    raise NotImplementedError(\nNotImplementedError: Unimplemented primitive in Pallas TPU lowering for KernelType.TC: dynamic_slice. Please file an issue on https://github.com/jax-ml/jax/issues.\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 9.9
    },
    {
      "name": "26_ConvTranspose3d_Add_HardSwish",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 4.1189,
      "generated_ms": 29.3485,
      "speedup": 0.14,
      "original_std_ms": 0.6963,
      "generated_std_ms": 0.0394,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 29.1
    },
    {
      "name": "27_Conv3d_HardSwish_GroupNorm_Mean",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call hardswish_kernel at /tmp/pallas_eval/generated/27_Conv3d_HardSwish_GroupNorm_Mean_gpt53.py:8 has block shape (Blocked(block_size=1), Blocked(block_size=1024)), array shape (1024, 174928), and index_map { lambda ; a:i32[] b:i32[]. let  in (a, b) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "e divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call hardswish_kernel at /tmp/pallas_eval/generated/27_Conv3d_HardSwish_GroupNorm_Mean_gpt53.py:8 has block shape (Blocked(block_size=1), Blocked(block_size=1024)), array shape (1024, 174928), and index_map { lambda ; a:i32[] b:i32[]. let  in (a, b) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 16.6
    },
    {
      "name": "28_BMM_InstanceNorm_Sum_ResidualAdd_Multiply",
      "status": "error",
      "error": "Incompatible shapes for broadcasting: shapes=[(16384,), (128, 128)]",
      "traceback": "x/_src/numpy/ufuncs.py\", line 1234, in add\n    x, y = promote_args(\"add\", x, y)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/util.py\", line 228, in promote_args\n    return promote_shapes(fun_name, *promote_dtypes(*args))\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/util.py\", line 64, in promote_shapes\n    result_rank = len(lax.broadcast_shapes(*shapes))\nValueError: Incompatible shapes for broadcasting: shapes=[(16384,), (128, 128)]\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 7.0
    },
    {
      "name": "29_Matmul_Mish_Mish",
      "status": "error",
      "error": "when() takes 1 positional argument but 2 were given",
      "traceback": "ib/python3.10/site-packages/jax/_src/pallas/pallas_call.py\", line 1210, in _trace_kernel_to_jaxpr\n    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/primitives.py\", line 874, in wrap_with_transforms\n    return f(*new_args)\n  File \"/tmp/pallas_eval/generated/29_Matmul_Mish_Mish_gpt53.py\", line 39, in kernel\n    pl.when(k_id == 0, init)\nTypeError: when() takes 1 positional argument but 2 were given\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 6.6
    },
    {
      "name": "2_ConvTranspose2d_BiasAdd_Clamp_Scaling_Clamp_Divide",
      "status": "error",
      "error": "Pytree for `in_specs` and `inputs` do not match. There are 1 mismatches, including:\n    * `in_specs` is a tuple of length 3 but `inputs` is a tuple of length 4, so the lengths do not match",
      "traceback": "hand/.local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py\", line 1687, in wrapped\n    kernel_args, grid_mapping = pallas_core.get_grid_mapping(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/core.py\", line 1137, in get_grid_mapping\n    raise ValueError(\nValueError: Pytree for `in_specs` and `inputs` do not match. There are 1 mismatches, including:\n    * `in_specs` is a tuple of length 3 but `inputs` is a tuple of length 4, so the lengths do not match\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 11.7
    },
    {
      "name": "30_Gemm_GroupNorm_Hardtanh",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel at /tmp/pallas_eval/generated/30_Gemm_GroupNorm_Hardtanh_gpt53.py:35 has block shape (Blocked(block_size=1), Blocked(block_size=1), Blocked(block_size=512)), array shape (1024, 16, 512), and index_map { lambda ; a:i32[] b:i32[]. let  in (a, b, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel at /tmp/pallas_eval/generated/30_Gemm_GroupNorm_Hardtanh_gpt53.py:35 has block shape (Blocked(block_size=1), Blocked(block_size=1), Blocked(block_size=512)), array shape (1024, 16, 512), and index_map { lambda ; a:i32[] b:i32[]. let  in (a, b, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 7.2
    },
    {
      "name": "31_Conv2d_Min_Add_Multiply",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 2.1998,
      "generated_ms": 19.8529,
      "speedup": 0.111,
      "original_std_ms": 0.0052,
      "generated_std_ms": 0.0191,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 23.9
    },
    {
      "name": "32_Conv2d_Scaling_Min",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 3.3754,
      "generated_ms": 17.2627,
      "speedup": 0.196,
      "original_std_ms": 0.0156,
      "generated_std_ms": 0.2023,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 18.6
    },
    {
      "name": "33_Gemm_Scale_BatchNorm",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 0.1573,
      "generated_ms": 2.2115,
      "speedup": 0.071,
      "original_std_ms": 0.0048,
      "generated_std_ms": 0.6459,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 7.9
    },
    {
      "name": "34_ConvTranspose3d_LayerNorm_GELU_Scaling",
      "correct": true,
      "max_diff": 2e-06,
      "correctness_reason": "ok",
      "original_ms": 22.4771,
      "generated_ms": 30.0373,
      "speedup": 0.748,
      "original_std_ms": 2.9514,
      "generated_std_ms": 0.0226,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 145.5
    },
    {
      "name": "35_Conv2d_Subtract_HardSwish_MaxPool_Mish",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 3.0587,
      "generated_ms": 17.2773,
      "speedup": 0.177,
      "original_std_ms": 0.009,
      "generated_std_ms": 0.0166,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 24.7
    },
    {
      "name": "36_ConvTranspose2d_Min_Sum_GELU_Add",
      "status": "error",
      "error": "Block shape for args[1] (= (Blocked(block_size=1), Blocked(block_size=1), Blocked(block_size=1), Blocked(block_size=1))) must have the same number of dimensions as the array shape (1, 1, 1).",
      "traceback": "hand/.local/lib/python3.10/site-packages/jax/_src/pallas/core.py\", line 999, in _convert_block_spec_to_block_mapping\n    return block_spec.to_block_mapping(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/core.py\", line 512, in to_block_mapping\n    raise ValueError(\nValueError: Block shape for args[1] (= (Blocked(block_size=1), Blocked(block_size=1), Blocked(block_size=1), Blocked(block_size=1))) must have the same number of dimensions as the array shape (1, 1, 1).\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 11.7
    },
    {
      "name": "37_Matmul_Swish_Sum_GroupNorm",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 5.1718,
      "generated_ms": 10.1946,
      "speedup": 0.507,
      "original_std_ms": 0.4577,
      "generated_std_ms": 1.1826,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 11.3
    },
    {
      "name": "38_ConvTranspose3d_AvgPool_Clamp_Softmax_Multiply",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 17.9249,
      "generated_ms": 28.7182,
      "speedup": 0.624,
      "original_std_ms": 0.0263,
      "generated_std_ms": 0.0271,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 59.9
    },
    {
      "name": "39_Gemm_Scale_BatchNorm",
      "status": "error",
      "error": "`indices` must not be longer than `shape`: indices=(None, slice(None, None, None)), shape=(128,)",
      "traceback": "l\n    out = acc + b_ref[None, :]\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py\", line 1083, in op\n    return getattr(self.aval, f\"_{name}\")(self, *args)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/state/indexing.py\", line 235, in from_indices_shape\n    raise ValueError(\"`indices` must not be longer than `shape`: \"\nValueError: `indices` must not be longer than `shape`: indices=(None, slice(None, None, None)), shape=(128,)\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 10.4
    },
    {
      "name": "3_ConvTranspose3d_Sum_LayerNorm_AvgPool_GELU",
      "status": "error",
      "error": "The kernel function in the pallas_call kernel_fn at /tmp/pallas_eval/generated/3_ConvTranspose3d_Sum_LayerNorm_AvgPool_GELU_gpt53.py:65 captures constants [ShapedArray(float32[], weak_type=True), ShapedArray(float32[64]), ShapedArray(float32[64])]. You should pass them as inputs",
      "traceback": "n wrapped\n    jaxpr, consts = _trace_kernel_to_jaxpr(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py\", line 1215, in _trace_kernel_to_jaxpr\n    raise ValueError(\nValueError: The kernel function in the pallas_call kernel_fn at /tmp/pallas_eval/generated/3_ConvTranspose3d_Sum_LayerNorm_AvgPool_GELU_gpt53.py:65 captures constants [ShapedArray(float32[], weak_type=True), ShapedArray(float32[64]), ShapedArray(float32[64])]. You should pass them as inputs\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 18.9
    },
    {
      "name": "40_Matmul_Scaling_ResidualAdd",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 1.3011,
      "generated_ms": 3.7942,
      "speedup": 0.343,
      "original_std_ms": 0.006,
      "generated_std_ms": 0.0082,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 10.3
    },
    {
      "name": "41_Gemm_BatchNorm_GELU_ReLU",
      "correct": true,
      "max_diff": 2e-06,
      "correctness_reason": "ok",
      "original_ms": 2.3608,
      "generated_ms": 3.5678,
      "speedup": 0.662,
      "original_std_ms": 0.3602,
      "generated_std_ms": 0.0103,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 15.6
    },
    {
      "name": "42_ConvTranspose2d_GlobalAvgPool_BiasAdd_LogSumExp_Sum_Multiply",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for outputs in pallas_call fused_kernel at /tmp/pallas_eval/generated/42_ConvTranspose2d_GlobalAvgPool_BiasAdd_LogSumExp_Sum_Multiply_gpt53.py:7 has block shape (Blocked(block_size=1), Blocked(block_size=1)), array shape (16, 1), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "y 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for outputs in pallas_call fused_kernel at /tmp/pallas_eval/generated/42_ConvTranspose2d_GlobalAvgPool_BiasAdd_LogSumExp_Sum_Multiply_gpt53.py:7 has block shape (Blocked(block_size=1), Blocked(block_size=1)), array shape (16, 1), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 17.2
    },
    {
      "name": "43_Conv3d_Max_LogSumExp_ReLU",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 5.0574,
      "generated_ms": 5.7579,
      "speedup": 0.878,
      "original_std_ms": 0.0648,
      "generated_std_ms": 0.0093,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 106.6
    },
    {
      "name": "44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call mean_kernel at /tmp/pallas_eval/generated/44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean_gpt53.py:43 has block shape (Blocked(block_size=1), Blocked(block_size=65536)), array shape (2048, 65536), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": " 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call mean_kernel at /tmp/pallas_eval/generated/44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean_gpt53.py:43 has block shape (Blocked(block_size=1), Blocked(block_size=65536)), array shape (2048, 65536), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 11.7
    },
    {
      "name": "45_Gemm_Sigmoid_LogSumExp",
      "status": "error",
      "error": "Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<~int32[]>with<DynamicJaxprTrace>, Traced<~int32[]>with<DynamicJaxprTrace>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).",
      "traceback": "al/lib/python3.10/site-packages/jax/_src/numpy/indexing.py\", line 929, in index_to_gather\n    raise IndexError(msg)\nIndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<~int32[]>with<DynamicJaxprTrace>, Traced<~int32[]>with<DynamicJaxprTrace>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 7.6
    },
    {
      "name": "46_Conv2d_Subtract_Tanh_Subtract_AvgPool",
      "correct": false,
      "max_diff": 0.331059,
      "correctness_reason": "values differ",
      "original_ms": 3.0606,
      "generated_ms": 14.2214,
      "speedup": 0.215,
      "original_std_ms": 0.0155,
      "generated_std_ms": 0.0393,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 22.2
    },
    {
      "name": "47_Conv3d_Mish_Tanh",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 2.602,
      "generated_ms": 11.551,
      "speedup": 0.225,
      "original_std_ms": 0.0059,
      "generated_std_ms": 0.0174,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 126.1
    },
    {
      "name": "48_Conv3d_Scaling_Tanh_Multiply_Sigmoid",
      "correct": false,
      "max_diff": 0.5,
      "correctness_reason": "values differ",
      "original_ms": 0.7276,
      "generated_ms": 7.8827,
      "speedup": 0.092,
      "original_std_ms": 0.0059,
      "generated_std_ms": 0.0155,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 26.1
    },
    {
      "name": "49_ConvTranspose3d_Softmax_Sigmoid",
      "correct": false,
      "max_diff": 0.503906,
      "correctness_reason": "values differ",
      "original_ms": 5.5273,
      "generated_ms": 4.6004,
      "speedup": 1.201,
      "original_std_ms": 0.6541,
      "generated_std_ms": 0.2501,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 33.1
    },
    {
      "name": "4_Conv2d_Mish_Mish",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 5.5319,
      "generated_ms": 38.0141,
      "speedup": 0.146,
      "original_std_ms": 0.0086,
      "generated_std_ms": 0.0265,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 30.7
    },
    {
      "name": "50_ConvTranspose3d_Scaling_AvgPool_BiasAdd_Scaling",
      "status": "error",
      "error": "The kernel function in the pallas_call <lambda> at /tmp/pallas_eval/generated/50_ConvTranspose3d_Scaling_AvgPool_BiasAdd_Scaling_gpt53.py:20 captures constants [ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. You should pass them as inputs",
      "traceback": "1715, in wrapped\n    jaxpr, consts = _trace_kernel_to_jaxpr(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py\", line 1215, in _trace_kernel_to_jaxpr\n    raise ValueError(\nValueError: The kernel function in the pallas_call <lambda> at /tmp/pallas_eval/generated/50_ConvTranspose3d_Scaling_AvgPool_BiasAdd_Scaling_gpt53.py:20 captures constants [ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. You should pass them as inputs\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 12.8
    },
    {
      "name": "51_Gemm_Subtract_GlobalAvgPool_LogSumExp_GELU_ResidualAdd",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call fused_kernel at /tmp/pallas_eval/generated/51_Gemm_Subtract_GlobalAvgPool_LogSumExp_GELU_ResidualAdd_gpt53.py:7 has block shape (Blocked(block_size=1), Blocked(block_size=8192)), array shape (2048, 8192), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call fused_kernel at /tmp/pallas_eval/generated/51_Gemm_Subtract_GlobalAvgPool_LogSumExp_GELU_ResidualAdd_gpt53.py:7 has block shape (Blocked(block_size=1), Blocked(block_size=8192)), array shape (2048, 8192), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 7.0
    },
    {
      "name": "52_Conv2d_Activation_BatchNorm",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call fused_kernel at /tmp/pallas_eval/generated/52_Conv2d_Activation_BatchNorm_gpt53.py:8 has block shape (Blocked(block_size=128), Blocked(block_size=64)), array shape (128, 1016064), and index_map { lambda ; a:i32[] b:i32[]. let  in (a, b) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call fused_kernel at /tmp/pallas_eval/generated/52_Conv2d_Activation_BatchNorm_gpt53.py:8 has block shape (Blocked(block_size=128), Blocked(block_size=64)), array shape (128, 1016064), and index_map { lambda ; a:i32[] b:i32[]. let  in (a, b) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 14.9
    },
    {
      "name": "53_Gemm_Scaling_Hardtanh_GELU",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 0.2107,
      "generated_ms": 4.12,
      "speedup": 0.051,
      "original_std_ms": 0.0056,
      "generated_std_ms": 0.9575,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 8.3
    },
    {
      "name": "54_Conv2d_Multiply_LeakyReLU_GELU",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 9.0944,
      "generated_ms": 21.1922,
      "speedup": 0.429,
      "original_std_ms": 0.2503,
      "generated_std_ms": 0.0805,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 23.9
    },
    {
      "name": "55_Matmul_MaxPool_Sum_Scale",
      "status": "error",
      "error": "Unimplemented primitive in Pallas TPU lowering for KernelType.TC: dynamic_slice. Please file an issue on https://github.com/jax-ml/jax/issues.",
      "traceback": "aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 2983, in _run_body\n    args = jaxpr_subcomp(lowering_context, jaxpr, *consts, i, *args)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1190, in jaxpr_subcomp\n    raise NotImplementedError(\nNotImplementedError: Unimplemented primitive in Pallas TPU lowering for KernelType.TC: dynamic_slice. Please file an issue on https://github.com/jax-ml/jax/issues.\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 27.6
    },
    {
      "name": "56_Matmul_Sigmoid_Sum",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Allocation (size=4294967296) would exceed memory (size=134217728) :: #allocation2 [shape = 'u8[4294967296]{0}', space=vmem, size = 0x100000000, tag = 'operand span for operand 1'] :: main.1",
      "traceback": " File \"/tmp/pallas_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 95, in eval_jaxkernelbench\n    test_out = jax.jit(gen_model.forward)(*inputs)\njaxlib._jax.XlaRuntimeError: RESOURCE_EXHAUSTED: Allocation (size=4294967296) would exceed memory (size=134217728) :: #allocation2 [shape = 'u8[4294967296]{0}', space=vmem, size = 0x100000000, tag = 'operand span for operand 1'] :: main.1\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 73.1
    },
    {
      "name": "57_Conv2d_ReLU_HardSwish",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 0.546,
      "generated_ms": 16.1261,
      "speedup": 0.034,
      "original_std_ms": 0.0047,
      "generated_std_ms": 0.0188,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 15.1
    },
    {
      "name": "58_ConvTranspose3d_LogSumExp_HardSwish_Subtract_Clamp",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel_fn at /tmp/pallas_eval/generated/58_ConvTranspose3d_LogSumExp_HardSwish_Subtract_Clamp_gpt53.py:34 has block shape (Blocked(block_size=1), Blocked(block_size=16), Blocked(block_size=1), Blocked(block_size=1), Blocked(block_size=32)), array shape (128, 16, 31, 63, 63), and index_map { lambda ; a:i32[] b:i32[] c:i32[] d:i32[]. let  in (a, 0:i32[], b, c, d) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "gs[0] in pallas_call kernel_fn at /tmp/pallas_eval/generated/58_ConvTranspose3d_LogSumExp_HardSwish_Subtract_Clamp_gpt53.py:34 has block shape (Blocked(block_size=1), Blocked(block_size=16), Blocked(block_size=1), Blocked(block_size=1), Blocked(block_size=32)), array shape (128, 16, 31, 63, 63), and index_map { lambda ; a:i32[] b:i32[] c:i32[] d:i32[]. let  in (a, 0:i32[], b, c, d) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 12.9
    },
    {
      "name": "59_Matmul_Swish_Scaling",
      "status": "error",
      "error": "Attempted boolean conversion of traced array with shape bool[].\nThe error occurred while tracing the function kernel at /tmp/pallas_eval/generated/59_Matmul_Swish_Scaling_gpt53.py:38 for pallas_call kernel. This value became a tracer due to JAX operations on these lines:\n\n  operation a:i32[] = mul b 128:i32[]\n    from line /tmp/pallas_eval/generated/59_Matmul_Swish_Scaling_gpt53.py:41 (kernel)\n\n  operation a:i32[] = add b 1:i32[]\n    from line /tmp/pallas_eval/generated/59_Matmul_Swish_Scaling_gpt53.py:41 (kernel)\n\n  operation a:i32[] = mul b 128:i32[]\n    from line /tmp/pallas_eval/generated/59_Matmul_Swish_Scaling_gpt53.py:41 (kernel)\n\n  operation a:bool[] = ge b 0:i32[]\n    from line /tmp/pallas_eval/generated/59_Matmul_Swish_Scaling_gpt53.py:41 (kernel)\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError",
      "traceback": "59_Matmul_Swish_Scaling_gpt53.py:41 (kernel)\n\n  operation a:i32[] = add b 1:i32[]\n    from line /tmp/pallas_eval/generated/59_Matmul_Swish_Scaling_gpt53.py:41 (kernel)\n\n  operation a:i32[] = mul b 128:i32[]\n    from line /tmp/pallas_eval/generated/59_Matmul_Swish_Scaling_gpt53.py:41 (kernel)\n\n  operation a:bool[] = ge b 0:i32[]\n    from line /tmp/pallas_eval/generated/59_Matmul_Swish_Scaling_gpt53.py:41 (kernel)\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 10.5
    },
    {
      "name": "5_ConvTranspose2d_Subtract_Tanh",
      "correct": false,
      "max_diff": 1.183219,
      "correctness_reason": "values differ",
      "original_ms": 11.1612,
      "generated_ms": 108.4837,
      "speedup": 0.103,
      "original_std_ms": 0.1011,
      "generated_std_ms": 0.2263,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 44.1
    },
    {
      "name": "60_ConvTranspose3d_Swish_GroupNorm_HardSwish",
      "correct": false,
      "max_diff": 4.345818,
      "correctness_reason": "values differ",
      "original_ms": 5.1941,
      "generated_ms": 89.7477,
      "speedup": 0.058,
      "original_std_ms": 0.0114,
      "generated_std_ms": 0.0433,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 36.2
    },
    {
      "name": "61_ConvTranspose3d_ReLU_GroupNorm",
      "correct": true,
      "max_diff": 5e-06,
      "correctness_reason": "ok",
      "original_ms": 1.7361,
      "generated_ms": 42.4043,
      "speedup": 0.041,
      "original_std_ms": 0.004,
      "generated_std_ms": 0.0197,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 86.7
    },
    {
      "name": "62_Matmul_GroupNorm_LeakyReLU_Sum",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel at /tmp/pallas_eval/generated/62_Matmul_GroupNorm_LeakyReLU_Sum_gpt53.py:59 has block shape (Blocked(block_size=1), Blocked(block_size=8192)), array shape (1024, 8192), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel at /tmp/pallas_eval/generated/62_Matmul_GroupNorm_LeakyReLU_Sum_gpt53.py:59 has block shape (Blocked(block_size=1), Blocked(block_size=8192)), array shape (1024, 8192), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 7.0
    },
    {
      "name": "63_Gemm_ReLU_Divide",
      "status": "error",
      "error": "INVALID_ARGUMENT: Mosaic failed to compile TPU kernel: Failed to verify layout for Mosaic kernel operand 4: XLA layout ({0:T(1024)S(1)}) does not match Mosaic layout ({0:T(128)S(1)}) for an operand of shape f32[8192]. Try changing your kernel block shape to (1024) to align with the XLA layout.",
      "traceback": "elbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 95, in eval_jaxkernelbench\n    test_out = jax.jit(gen_model.forward)(*inputs)\njaxlib._jax.XlaRuntimeError: INVALID_ARGUMENT: Mosaic failed to compile TPU kernel: Failed to verify layout for Mosaic kernel operand 4: XLA layout ({0:T(1024)S(1)}) does not match Mosaic layout ({0:T(128)S(1)}) for an operand of shape f32[8192]. Try changing your kernel block shape to (1024) to align with the XLA layout.\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 7.2
    },
    {
      "name": "64_Gemm_LogSumExp_LeakyReLU_LeakyReLU_GELU_GELU",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 0.1306,
      "generated_ms": 0.132,
      "speedup": 0.989,
      "original_std_ms": 0.004,
      "generated_std_ms": 0.0042,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 7.6
    },
    {
      "name": "65_Conv2d_AvgPool_Sigmoid_Sum",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for outputs in pallas_call kernel_fn at /tmp/pallas_eval/generated/65_Conv2d_AvgPool_Sigmoid_Sum_gpt53.py:20 has block shape (Blocked(block_size=1), Blocked(block_size=1)), array shape (128, 1), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for outputs in pallas_call kernel_fn at /tmp/pallas_eval/generated/65_Conv2d_AvgPool_Sigmoid_Sum_gpt53.py:20 has block shape (Blocked(block_size=1), Blocked(block_size=1)), array shape (128, 1), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 9.0
    },
    {
      "name": "66_Matmul_Dropout_Softmax",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Allocation (size=1073741824) would exceed memory (size=134217728) :: #allocation2 [shape = 'u8[1073741824]{0}', space=vmem, size = 0x40000000, tag = 'operand span for operand 1'] :: main.1",
      "traceback": "  File \"/tmp/pallas_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 95, in eval_jaxkernelbench\n    test_out = jax.jit(gen_model.forward)(*inputs)\njaxlib._jax.XlaRuntimeError: RESOURCE_EXHAUSTED: Allocation (size=1073741824) would exceed memory (size=134217728) :: #allocation2 [shape = 'u8[1073741824]{0}', space=vmem, size = 0x40000000, tag = 'operand span for operand 1'] :: main.1\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 17.5
    },
    {
      "name": "67_Conv2d_GELU_GlobalAvgPool",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for outputs in pallas_call post_kernel at /tmp/pallas_eval/generated/67_Conv2d_GELU_GlobalAvgPool_gpt53.py:8 has block shape (Blocked(block_size=1), Blocked(block_size=64)), array shape (128, 64), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": " your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for outputs in pallas_call post_kernel at /tmp/pallas_eval/generated/67_Conv2d_GELU_GlobalAvgPool_gpt53.py:8 has block shape (Blocked(block_size=1), Blocked(block_size=64)), array shape (128, 64), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 11.7
    },
    {
      "name": "68_Matmul_Min_Subtract",
      "status": "error",
      "error": "Pytree for `in_specs` and `inputs` do not match. There are 1 mismatches, including:\n    * `in_specs` is a tuple of length 3 but `inputs` is a tuple of length 4, so the lengths do not match",
      "traceback": "hand/.local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py\", line 1687, in wrapped\n    kernel_args, grid_mapping = pallas_core.get_grid_mapping(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/core.py\", line 1137, in get_grid_mapping\n    raise ValueError(\nValueError: Pytree for `in_specs` and `inputs` do not match. There are 1 mismatches, including:\n    * `in_specs` is a tuple of length 3 but `inputs` is a tuple of length 4, so the lengths do not match\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 7.2
    },
    {
      "name": "69_Conv2d_HardSwish_ReLU",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for outputs in pallas_call kernel_fn at /tmp/pallas_eval/generated/69_Conv2d_HardSwish_ReLU_gpt53.py:34 has block shape (Blocked(block_size=1), Blocked(block_size=64)), array shape (2032128, 64), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "f your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for outputs in pallas_call kernel_fn at /tmp/pallas_eval/generated/69_Conv2d_HardSwish_ReLU_gpt53.py:34 has block shape (Blocked(block_size=1), Blocked(block_size=64)), array shape (2032128, 64), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 10.1
    },
    {
      "name": "6_Conv3d_Softmax_MaxPool_MaxPool",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 0.5367,
      "generated_ms": 22.7687,
      "speedup": 0.024,
      "original_std_ms": 0.0063,
      "generated_std_ms": 0.0229,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 21.8
    },
    {
      "name": "70_Gemm_Sigmoid_Scaling_ResidualAdd",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 0.1576,
      "generated_ms": 2.2154,
      "speedup": 0.071,
      "original_std_ms": 0.0035,
      "generated_std_ms": 0.6437,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 7.8
    },
    {
      "name": "71_Conv2d_Divide_LeakyReLU",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 0.5481,
      "generated_ms": 16.1322,
      "speedup": 0.034,
      "original_std_ms": 0.0043,
      "generated_std_ms": 0.0172,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 14.9
    },
    {
      "name": "72_ConvTranspose3d_BatchNorm_AvgPool_AvgPool",
      "status": "error",
      "error": "RESOURCE_EXHAUSTED: Allocation (size=2113929216) would exceed memory (size=134217728) :: #allocation2 [shape = 'u8[2113929216]{0}', space=vmem, size = 0x7e000000, tag = 'operand span for operand 0'] :: main.1",
      "traceback": "  File \"/tmp/pallas_eval/eval_harness.py\", line 168, in main\n    result = eval_jaxkernelbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 95, in eval_jaxkernelbench\n    test_out = jax.jit(gen_model.forward)(*inputs)\njaxlib._jax.XlaRuntimeError: RESOURCE_EXHAUSTED: Allocation (size=2113929216) would exceed memory (size=134217728) :: #allocation2 [shape = 'u8[2113929216]{0}', space=vmem, size = 0x7e000000, tag = 'operand span for operand 0'] :: main.1\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 61.0
    },
    {
      "name": "73_Conv2d_BatchNorm_Scaling",
      "status": "error",
      "error": "`broadcast_to` is a Triton-specific primitive. Please consider using `jnp.broadcast_to` instead.",
      "traceback": "in body_func\n    return jaxpr_subcomp(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1171, in jaxpr_subcomp\n    ans = lowering_rules[ctx.kernel_type][eqn.primitive](\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1913, in _broadcast_to_lowering_rule\n    raise RuntimeError(\nRuntimeError: `broadcast_to` is a Triton-specific primitive. Please consider using `jnp.broadcast_to` instead.\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 10.5
    },
    {
      "name": "74_ConvTranspose3d_LeakyReLU_Multiply_LeakyReLU_Max",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 3.5829,
      "generated_ms": 9.8834,
      "speedup": 0.363,
      "original_std_ms": 0.7571,
      "generated_std_ms": 0.0137,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 39.5
    },
    {
      "name": "75_Gemm_GroupNorm_Min_BiasAdd",
      "status": "error",
      "error": "INTERNAL: Mosaic failed to compile TPU kernel: infer-vector-layout: unsupported shape cast\n\nThe MLIR operation involved:\n  %21 = \"vector.shape_cast\"(%17) : (vector<128x8192xf32>) -> vector<128x512x16xf32>\n\nPlease report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke\n",
      "traceback": " y = pl.pallas_call(\n  File \"/tmp/pallas_eval/generated/75_Gemm_GroupNorm_Min_BiasAdd_gpt53.py\", line 17, in fused_kernel\n    xg = x.reshape(B, G, group_size)\njax._src.pallas.mosaic.error_handling.MosaicError: INTERNAL: Mosaic failed to compile TPU kernel: infer-vector-layout: unsupported shape cast\n\nThe MLIR operation involved:\n  %21 = \"vector.shape_cast\"(%17) : (vector<128x8192xf32>) -> vector<128x512x16xf32>\n\nPlease report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke\n\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 7.5
    },
    {
      "name": "76_Gemm_Add_ReLU",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 0.1709,
      "generated_ms": 2.203,
      "speedup": 0.078,
      "original_std_ms": 0.0149,
      "generated_std_ms": 0.5972,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 7.8
    },
    {
      "name": "77_ConvTranspose3d_Scale_BatchNorm_GlobalAvgPool",
      "status": "error",
      "error": "The kernel function in the pallas_call kernel_fn at /tmp/pallas_eval/generated/77_ConvTranspose3d_Scale_BatchNorm_GlobalAvgPool_gpt53.py:48 captures constants [ShapedArray(float32[128]), ShapedArray(float32[128]), ShapedArray(float32[128]), ShapedArray(float32[128])]. You should pass them as inputs",
      "traceback": " consts = _trace_kernel_to_jaxpr(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py\", line 1215, in _trace_kernel_to_jaxpr\n    raise ValueError(\nValueError: The kernel function in the pallas_call kernel_fn at /tmp/pallas_eval/generated/77_ConvTranspose3d_Scale_BatchNorm_GlobalAvgPool_gpt53.py:48 captures constants [ShapedArray(float32[128]), ShapedArray(float32[128]), ShapedArray(float32[128]), ShapedArray(float32[128])]. You should pass them as inputs\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 19.6
    },
    {
      "name": "78_ConvTranspose3d_Max_Max_Sum",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 6.7623,
      "generated_ms": 6.7659,
      "speedup": 0.999,
      "original_std_ms": 0.0134,
      "generated_std_ms": 0.0117,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 193.5
    },
    {
      "name": "79_Conv3d_Multiply_InstanceNorm_Clamp_Multiply_Max",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 0.3999,
      "generated_ms": 11.3876,
      "speedup": 0.035,
      "original_std_ms": 0.0068,
      "generated_std_ms": 0.0212,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 23.9
    },
    {
      "name": "7_Conv3d_ReLU_LeakyReLU_GELU_Sigmoid_BiasAdd",
      "correct": false,
      "max_diff": 0.5,
      "correctness_reason": "values differ",
      "original_ms": 2.6032,
      "generated_ms": 23.5718,
      "speedup": 0.11,
      "original_std_ms": 0.4146,
      "generated_std_ms": 0.038,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 18.6
    },
    {
      "name": "80_Gemm_Max_Subtract_GELU",
      "status": "error",
      "error": "Attempted boolean conversion of traced array with shape bool[].\nThe error occurred while tracing the function body at /tmp/pallas_eval/generated/80_Gemm_Max_Subtract_GELU_gpt53.py:30 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError",
      "traceback": " in from_slice\n    start, step, size = core.canonicalize_slice(slc, size)\njax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].\nThe error occurred while tracing the function body at /tmp/pallas_eval/generated/80_Gemm_Max_Subtract_GELU_gpt53.py:30 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 6.8
    },
    {
      "name": "81_Gemm_Swish_Divide_Clamp_Tanh_Clamp",
      "status": "error",
      "error": "Attempted boolean conversion of traced array with shape bool[].\nThe error occurred while tracing the function body at /tmp/pallas_eval/generated/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp_gpt53.py:31 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError",
      "traceback": "ce\n    start, step, size = core.canonicalize_slice(slc, size)\njax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].\nThe error occurred while tracing the function body at /tmp/pallas_eval/generated/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp_gpt53.py:31 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 6.9
    },
    {
      "name": "82_Conv2d_Tanh_Scaling_BiasAdd_Max",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "status": "error",
      "error": "ERROR: Command timed out"
    },
    {
      "name": "83_Conv3d_GroupNorm_Min_Clamp_Dropout",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 1.7027,
      "generated_ms": 16.3158,
      "speedup": 0.104,
      "original_std_ms": 0.0919,
      "generated_std_ms": 0.1225,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 39.4
    },
    {
      "name": "84_Gemm_BatchNorm_Scaling_Softmax",
      "status": "error",
      "error": "Unimplemented primitive in Pallas TPU lowering for KernelType.TC: stop_gradient. Please file an issue on https://github.com/jax-ml/jax/issues.",
      "traceback": "f(*func_args, **func_kwargs)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1046, in body_func\n    return jaxpr_subcomp(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1190, in jaxpr_subcomp\n    raise NotImplementedError(\nNotImplementedError: Unimplemented primitive in Pallas TPU lowering for KernelType.TC: stop_gradient. Please file an issue on https://github.com/jax-ml/jax/issues.\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 6.8
    },
    {
      "name": "85_Conv2d_GroupNorm_Scale_MaxPool_Clamp",
      "status": "error",
      "error": "cannot reshape array of shape (1, 64, 1, 1) (size 64) into shape (128, 16, 63504) (size 130056192)",
      "traceback": "ryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py\", line 313, in _reshape\n    newshape = _compute_newshape(self, args[0] if len(args) == 1 else args)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py\", line 474, in _compute_newshape\n    raise TypeError(f\"cannot reshape array of shape {arr.shape} (size {arr.size}) \"\nTypeError: cannot reshape array of shape (1, 64, 1, 1) (size 64) into shape (128, 16, 63504) (size 130056192)\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 9.8
    },
    {
      "name": "86_Matmul_Divide_GELU",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 0.1571,
      "generated_ms": 2.2128,
      "speedup": 0.071,
      "original_std_ms": 0.0038,
      "generated_std_ms": 0.6605,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 7.7
    },
    {
      "name": "87_Conv2d_Subtract_Subtract_Mish",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel at /tmp/pallas_eval/generated/87_Conv2d_Subtract_Subtract_Mish_gpt53.py:51 has block shape (Blocked(block_size=128), Blocked(block_size=4)), array shape (8192, 64516), and index_map { lambda ; a:i32[] b:i32[]. let  in (a, b) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "lock shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call kernel at /tmp/pallas_eval/generated/87_Conv2d_Subtract_Subtract_Mish_gpt53.py:51 has block shape (Blocked(block_size=128), Blocked(block_size=4)), array shape (8192, 64516), and index_map { lambda ; a:i32[] b:i32[]. let  in (a, b) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 10.8
    },
    {
      "name": "88_Gemm_GroupNorm_Swish_Multiply_Swish",
      "status": "error",
      "error": "Block shape for args[0] (= (Blocked(block_size=8), Blocked(block_size=32))) must have the same number of dimensions as the array shape (1024, 256, 32).",
      "traceback": "k_mappings = map(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/core.py\", line 999, in _convert_block_spec_to_block_mapping\n    return block_spec.to_block_mapping(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/core.py\", line 512, in to_block_mapping\n    raise ValueError(\nValueError: Block shape for args[0] (= (Blocked(block_size=8), Blocked(block_size=32))) must have the same number of dimensions as the array shape (1024, 256, 32).\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 6.7
    },
    {
      "name": "89_ConvTranspose3d_MaxPool_Softmax_Subtract_Swish_Max",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 2.5431,
      "generated_ms": 11.6497,
      "speedup": 0.218,
      "original_std_ms": 0.5409,
      "generated_std_ms": 0.0161,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 34.2
    },
    {
      "name": "8_Conv3d_Divide_Max_GlobalAvgPool_BiasAdd_Sum",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "status": "error",
      "error": "ERROR: Command timed out"
    },
    {
      "name": "90_Conv3d_LeakyReLU_Sum_Clamp_GELU",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 2.1208,
      "generated_ms": 34.6601,
      "speedup": 0.061,
      "original_std_ms": 0.3068,
      "generated_std_ms": 0.1424,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 19.9
    },
    {
      "name": "91_ConvTranspose2d_Softmax_BiasAdd_Scaling_Sigmoid",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 5.8515,
      "generated_ms": 49.3765,
      "speedup": 0.119,
      "original_std_ms": 0.7433,
      "generated_std_ms": 0.0339,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 26.7
    },
    {
      "name": "92_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call fused_kernel at /tmp/pallas_eval/generated/92_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp_gpt53.py:8 has block shape (Blocked(block_size=1), Blocked(block_size=64)), array shape (2032128, 64), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "y 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call fused_kernel at /tmp/pallas_eval/generated/92_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp_gpt53.py:8 has block shape (Blocked(block_size=1), Blocked(block_size=64)), array shape (2032128, 64), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 10.0
    },
    {
      "name": "93_ConvTranspose2d_Add_Min_GELU_Multiply",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 1.2509,
      "generated_ms": 80.5851,
      "speedup": 0.016,
      "original_std_ms": 0.0032,
      "generated_std_ms": 0.0331,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 31.2
    },
    {
      "name": "94_Gemm_BiasAdd_Hardtanh_Mish_GroupNorm",
      "status": "error",
      "error": "Block shape for args[0] (= (Blocked(block_size=1), Blocked(block_size=8192))) must have the same number of dimensions as the array shape (8192,).",
      "traceback": "n_block_mappings = map(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/core.py\", line 999, in _convert_block_spec_to_block_mapping\n    return block_spec.to_block_mapping(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/core.py\", line 512, in to_block_mapping\n    raise ValueError(\nValueError: Block shape for args[0] (= (Blocked(block_size=1), Blocked(block_size=8192))) must have the same number of dimensions as the array shape (8192,).\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 7.1
    },
    {
      "name": "95_Matmul_Add_Swish_Tanh_GELU_Hardtanh",
      "status": "error",
      "error": "INVALID_ARGUMENT: Mosaic failed to compile TPU kernel: Failed to verify layout for Mosaic kernel operand 4: XLA layout ({0:T(1024)S(1)}) does not match Mosaic layout ({0:T(128)S(1)}) for an operand of shape f32[8192]. Try changing your kernel block shape to (1024) to align with the XLA layout.",
      "traceback": "elbench(args.original, args.generated, args.name)\n  File \"/tmp/pallas_eval/eval_harness.py\", line 95, in eval_jaxkernelbench\n    test_out = jax.jit(gen_model.forward)(*inputs)\njaxlib._jax.XlaRuntimeError: INVALID_ARGUMENT: Mosaic failed to compile TPU kernel: Failed to verify layout for Mosaic kernel operand 4: XLA layout ({0:T(1024)S(1)}) does not match Mosaic layout ({0:T(128)S(1)}) for an operand of shape f32[8192]. Try changing your kernel block shape to (1024) to align with the XLA layout.\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 7.0
    },
    {
      "name": "96_ConvTranspose3d_Multiply_Max_GlobalAvgPool_Clamp",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 1.6996,
      "generated_ms": 113.7265,
      "speedup": 0.015,
      "original_std_ms": 0.3746,
      "generated_std_ms": 0.0249,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 38.1
    },
    {
      "name": "97_Matmul_BatchNorm_BiasAdd_Divide_Swish",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 0.16,
      "generated_ms": 2.2094,
      "speedup": 0.072,
      "original_std_ms": 0.0054,
      "generated_std_ms": 0.6562,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 7.8
    },
    {
      "name": "98_Matmul_AvgPool_GELU_Scale_Max",
      "status": "error",
      "error": "INTERNAL: Mosaic failed to compile TPU kernel: infer-vector-layout: unsupported shape cast\n\nThe MLIR operation involved:\n  %20 = \"vector.shape_cast\"(%19) : (vector<128x8192xf32>) -> vector<128x512x16xf32>\n\nPlease report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke\n",
      "traceback": "f: fused_kernel(\n  File \"/tmp/pallas_eval/generated/98_Matmul_AvgPool_GELU_Scale_Max_gpt53.py\", line 20, in fused_kernel\n    y = jnp.reshape(y, (B, new_F, K))\njax._src.pallas.mosaic.error_handling.MosaicError: INTERNAL: Mosaic failed to compile TPU kernel: infer-vector-layout: unsupported shape cast\n\nThe MLIR operation involved:\n  %20 = \"vector.shape_cast\"(%19) : (vector<128x8192xf32>) -> vector<128x512x16xf32>\n\nPlease report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke\n\n",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 7.3
    },
    {
      "name": "99_Matmul_GELU_Softmax",
      "correct": true,
      "max_diff": 0.00769,
      "correctness_reason": "ok",
      "original_ms": 0.1669,
      "generated_ms": 2.2071,
      "speedup": 0.076,
      "original_std_ms": 0.0044,
      "generated_std_ms": 0.6374,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 7.8
    },
    {
      "name": "9_Matmul_Subtract_Multiply_ReLU",
      "correct": true,
      "max_diff": 0.0,
      "correctness_reason": "ok",
      "original_ms": 0.1576,
      "generated_ms": 2.1976,
      "speedup": 0.072,
      "original_std_ms": 0.004,
      "generated_std_ms": 0.6005,
      "status": "success",
      "model": "gpt53",
      "suite": "jaxkernelbench",
      "level": "level2",
      "eval_time_s": 7.9
    },
    {
      "name": "cross_entropy",
      "status": "error",
      "error": "Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<~int32[]>with<DynamicJaxprTrace>, Traced<~int32[]>with<DynamicJaxprTrace>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).",
      "traceback": "al/lib/python3.10/site-packages/jax/_src/numpy/indexing.py\", line 929, in index_to_gather\n    raise IndexError(msg)\nIndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<~int32[]>with<DynamicJaxprTrace>, Traced<~int32[]>with<DynamicJaxprTrace>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).\n",
      "model": "gpt53",
      "suite": "priority_kernels",
      "level": null,
      "eval_time_s": 14.8
    },
    {
      "name": "flash_attention",
      "status": "error",
      "error": "dot_general requires contracting dimensions to have the same shape, got (128,) and (2048,).",
      "traceback": "lib/python3.10/site-packages/jax/_src/pallas/primitives.py\", line 874, in wrap_with_transforms\n    return f(*new_args)\n  File \"/tmp/pallas_eval/generated/flash_attention_gpt53.py\", line 37, in attention_kernel\n    attn = jnp.matmul(q, k.T) * scale\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/tensor_contractions.py\", line 245, in matmul\n    out = lax.dot_general(\nTypeError: dot_general requires contracting dimensions to have the same shape, got (128,) and (2048,).\n",
      "model": "gpt53",
      "suite": "priority_kernels",
      "level": null,
      "eval_time_s": 13.1
    },
    {
      "name": "flex_attention",
      "correct": true,
      "max_diff": 8e-06,
      "correctness_reason": "ok",
      "original_ms": 2.929,
      "generated_ms": 2.8956,
      "speedup": 1.012,
      "original_std_ms": 0.497,
      "generated_std_ms": 0.0296,
      "status": "success",
      "model": "gpt53",
      "suite": "priority_kernels",
      "level": null,
      "eval_time_s": 21.0
    },
    {
      "name": "gemm",
      "correct": true,
      "max_diff": 0.03125,
      "correctness_reason": "ok",
      "original_ms": 5.4762,
      "generated_ms": 28.6178,
      "speedup": 0.191,
      "original_std_ms": 0.0104,
      "generated_std_ms": 0.3434,
      "status": "success",
      "model": "gpt53",
      "suite": "priority_kernels",
      "level": null,
      "eval_time_s": 16.9
    },
    {
      "name": "gqa_attention",
      "status": "error",
      "error": "Block shape for args[0] (= (Blocked(block_size=2048), Blocked(block_size=128))) must have the same number of dimensions as the array shape (128, 2048, 128).",
      "traceback": "pings = map(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/core.py\", line 999, in _convert_block_spec_to_block_mapping\n    return block_spec.to_block_mapping(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/core.py\", line 512, in to_block_mapping\n    raise ValueError(\nValueError: Block shape for args[0] (= (Blocked(block_size=2048), Blocked(block_size=128))) must have the same number of dimensions as the array shape (128, 2048, 128).\n",
      "model": "gpt53",
      "suite": "priority_kernels",
      "level": null,
      "eval_time_s": 15.3
    },
    {
      "name": "mamba2_ssd",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[3] in pallas_call _ssd_kernel at /tmp/pallas_eval/generated/mamba2_ssd_gpt53.py:31 has block shape (Blocked(block_size=1), Blocked(block_size=1), Blocked(block_size=2048)), array shape (1, 64, 2048), and index_map let _where = { lambda ; a:bool[] b:i32[] c:i32[]. let\n    d:i32[] = select_n a c b\n  in (d,) } in\n{ lambda ; e:i32[]. let\n    f:i32[] = pjit[\n      name=floor_divide\n      jaxpr={ lambda ; e:i32[] g:i32[]. let\n          h:i32[] = convert_element_type[new_dtype=int32 weak_type=False] g\n          i:i32[] = div e h\n          j:i32[] = sign e\n          k:i32[] = sign h\n          l:bool[] = ne j k\n          m:i32[] = rem e h\n          n:bool[] = ne m 0:i32[]\n          o:bool[] = and l n\n          p:i32[] = sub i 1:i32[]\n          f:i32[] = pjit[name=_where jaxpr=_where] o p i\n        in (f,) }\n    ] e 64:i32[]\n    q:i32[] = pjit[\n      name=remainder\n      jaxpr={ lambda ; e:i32[] r:i32[]. let\n          s:i32[] = convert_element_type[new_dtype=int32 weak_type=False] r\n          t:bool[] = eq s 0:i32[]\n          u:i32[] = pjit[name=_where jaxpr=_where] t 1:i32[] s\n          v:i32[] = rem e u\n          w:bool[] = ne v 0:i32[]\n          x:bool[] = lt v 0:i32[]\n          y:bool[] = lt u 0:i32[]\n          z:bool[] = ne x y\n          ba:bool[] = and z w\n          bb:i32[] = add v u\n          q:i32[] = select_n ba v bb\n        in (q,) }\n    ] e 64:i32[]\n  in (f, q, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "q s 0:i32[]\n          u:i32[] = pjit[name=_where jaxpr=_where] t 1:i32[] s\n          v:i32[] = rem e u\n          w:bool[] = ne v 0:i32[]\n          x:bool[] = lt v 0:i32[]\n          y:bool[] = lt u 0:i32[]\n          z:bool[] = ne x y\n          ba:bool[] = and z w\n          bb:i32[] = add v u\n          q:i32[] = select_n ba v bb\n        in (q,) }\n    ] e 64:i32[]\n  in (f, q, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "priority_kernels",
      "level": null,
      "eval_time_s": 9.4
    },
    {
      "name": "megablox_gmm",
      "status": "error",
      "error": "Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(0, Traced<int32[]>with<DynamicJaxprTrace>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).",
      "traceback": "r_update\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/indexing.py\", line 929, in index_to_gather\n    raise IndexError(msg)\nIndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(0, Traced<int32[]>with<DynamicJaxprTrace>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).\n",
      "model": "gpt53",
      "suite": "priority_kernels",
      "level": null,
      "eval_time_s": 9.6
    },
    {
      "name": "mla_attention",
      "correct": true,
      "max_diff": 0.007812,
      "correctness_reason": "ok",
      "original_ms": 4.5167,
      "generated_ms": 7.7222,
      "speedup": 0.585,
      "original_std_ms": 0.1151,
      "generated_std_ms": 0.3572,
      "status": "success",
      "model": "gpt53",
      "suite": "priority_kernels",
      "level": null,
      "eval_time_s": 21.2
    },
    {
      "name": "paged_attention",
      "status": "error",
      "error": "'F32Type' object has no attribute 'element_type'",
      "traceback": "ody\n    args = jaxpr_subcomp(lowering_context, jaxpr, *consts, i, *args)\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 1171, in jaxpr_subcomp\n    ans = lowering_rules[ctx.kernel_type][eqn.primitive](\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py\", line 2044, in _dot_general_lowering_rule\n    val_type = out_type.element_type\nAttributeError: 'F32Type' object has no attribute 'element_type'\n",
      "model": "gpt53",
      "suite": "priority_kernels",
      "level": null,
      "eval_time_s": 14.2
    },
    {
      "name": "ragged_dot",
      "correct": true,
      "max_diff": 0.03125,
      "correctness_reason": "ok",
      "original_ms": 1.3736,
      "generated_ms": 8.1334,
      "speedup": 0.169,
      "original_std_ms": 0.0069,
      "generated_std_ms": 0.0159,
      "status": "success",
      "model": "gpt53",
      "suite": "priority_kernels",
      "level": null,
      "eval_time_s": 20.2
    },
    {
      "name": "ragged_paged_attention",
      "status": "error",
      "error": "The __index__() method was called on traced array with shape int32[]\nThe error occurred while tracing the function workload at /tmp/pallas_eval/originals/ragged_paged_attention_original.py:60 for jit. This concrete value was not available in Python because it depends on the value of the argument num_seqs.\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerIntegerConversionError",
      "traceback": "ginal.py\", line 74, in workload\n    for i in range(num_seqs[0]):\njax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[]\nThe error occurred while tracing the function workload at /tmp/pallas_eval/originals/ragged_paged_attention_original.py:60 for jit. This concrete value was not available in Python because it depends on the value of the argument num_seqs.\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerIntegerConversionError\n",
      "model": "gpt53",
      "suite": "priority_kernels",
      "level": null,
      "eval_time_s": 18.8
    },
    {
      "name": "retnet_retention",
      "status": "error",
      "error": "Block shape for args[0] (= (Blocked(block_size=128), Blocked(block_size=256))) must have the same number of dimensions as the array shape (16, 2048, 256).",
      "traceback": "appings = map(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/core.py\", line 999, in _convert_block_spec_to_block_mapping\n    return block_spec.to_block_mapping(\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/core.py\", line 512, in to_block_mapping\n    raise ValueError(\nValueError: Block shape for args[0] (= (Blocked(block_size=128), Blocked(block_size=256))) must have the same number of dimensions as the array shape (16, 2048, 256).\n",
      "model": "gpt53",
      "suite": "priority_kernels",
      "level": null,
      "eval_time_s": 13.8
    },
    {
      "name": "rms_norm",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call rmsnorm_kernel at /tmp/pallas_eval/generated/rms_norm_gpt53.py:34 has block shape (Blocked(block_size=1), Blocked(block_size=8192)), array shape (2048, 8192), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "mensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call rmsnorm_kernel at /tmp/pallas_eval/generated/rms_norm_gpt53.py:34 has block shape (Blocked(block_size=1), Blocked(block_size=8192)), array shape (2048, 8192), and index_map { lambda ; a:i32[]. let  in (a, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "priority_kernels",
      "level": null,
      "eval_time_s": 7.7
    },
    {
      "name": "sparse_attention",
      "status": "error",
      "error": "The Pallas TPU lowering currently requires that the last two dimensions of your block shape are divisible by 8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call attention_kernel at /tmp/pallas_eval/generated/sparse_attention_gpt53.py:35 has block shape (Blocked(block_size=1), Blocked(block_size=1), Blocked(block_size=128)), array shape (64, 2048, 128), and index_map { lambda ; a:i32[] b:i32[]. let  in (a, b, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec",
      "traceback": "8 and 128 respectively, or be equal to the respective dimensions of the overall array. Block spec for args[0] in pallas_call attention_kernel at /tmp/pallas_eval/generated/sparse_attention_gpt53.py:35 has block shape (Blocked(block_size=1), Blocked(block_size=1), Blocked(block_size=128)), array shape (64, 2048, 128), and index_map { lambda ; a:i32[] b:i32[]. let  in (a, b, 0:i32[]) }, in memory space None.\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec\n",
      "model": "gpt53",
      "suite": "priority_kernels",
      "level": null,
      "eval_time_s": 14.8
    },
    {
      "name": "sparse_moe",
      "status": "error",
      "error": "Invalid dtype for `swap`. Ref dtype: bfloat16. Value dtype: float32. ",
      "traceback": "home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/pallas/primitives.py\", line 874, in wrap_with_transforms\n    return f(*new_args)\n  File \"/tmp/pallas_eval/generated/sparse_moe_gpt53.py\", line 42, in combine_kernel\n    out_ref[...] = acc\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py\", line 1083, in op\n    return getattr(self.aval, f\"_{name}\")(self, *args)\nValueError: Invalid dtype for `swap`. Ref dtype: bfloat16. Value dtype: float32. \n",
      "model": "gpt53",
      "suite": "priority_kernels",
      "level": null,
      "eval_time_s": 21.9
    },
    {
      "name": "swiglu_mlp",
      "correct": false,
      "max_diff": 0.25,
      "correctness_reason": "values differ",
      "original_ms": 4.0776,
      "generated_ms": 5.3419,
      "speedup": 0.763,
      "original_std_ms": 0.008,
      "generated_std_ms": 0.006,
      "status": "success",
      "model": "gpt53",
      "suite": "priority_kernels",
      "level": null,
      "eval_time_s": 21.4
    },
    {
      "name": "triangle_multiplication",
      "status": "error",
      "error": "Invalid dtype for `swap`. Ref dtype: bfloat16. Value dtype: float32. ",
      "traceback": ".local/lib/python3.10/site-packages/jax/_src/pallas/primitives.py\", line 874, in wrap_with_transforms\n    return f(*new_args)\n  File \"/tmp/pallas_eval/generated/triangle_multiplication_gpt53.py\", line 57, in kernel_fn\n    o_ref[i, :, :] = out * gate\n  File \"/home/aryatschand/.local/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py\", line 1083, in op\n    return getattr(self.aval, f\"_{name}\")(self, *args)\nValueError: Invalid dtype for `swap`. Ref dtype: bfloat16. Value dtype: float32. \n",
      "model": "gpt53",
      "suite": "priority_kernels",
      "level": null,
      "eval_time_s": 9.8
    }
  ]
}