# Task

You are optimizing a JAX kernel for TPU v6e (Trillium) using the Pallas programming model (`jax.experimental.pallas`).
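For orientation, here is a minimal sketch of the Pallas model. The kernel body receives `Ref`s to blocks held in VMEM and writes its result into an output ref, and `pl.pallas_call` wraps the body into an ordinary JAX-callable function. The names `add_kernel` and `add` are illustrative only. `interpret=True` is passed solely so the sketch also runs off-TPU; drop it on the v6e.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Each ref points at a block in VMEM: load both inputs, add, store the result.
    o_ref[...] = x_ref[...] + y_ref[...]


def add(x, y, interpret=False):
    # No grid and no BlockSpecs are given, so each whole array is a single block.
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=interpret,
    )(x, y)


# Shape (8, 128) matches the native f32 tile on TPU (sublanes x lanes).
x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
y = jnp.ones_like(x)
out = add(x, y, interpret=True)  # interpret mode: runs on CPU for testing
```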

Target hardware: TPU v6e. VMEM = 128 MiB, 1 TensorCore per chip, peak ~918 TFLOPS bf16, HBM ~1526 GB/s.
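These figures support a quick roofline check. Dividing peak FLOP/s by HBM bandwidth gives the arithmetic intensity (FLOPs per byte of HBM traffic) above which a kernel can be compute-bound rather than memory-bound. The sketch below treats "GB/s" as 10^9 bytes/s. The 4096-cubed bf16 matmul and the 512-wide tiles are hypothetical examples, not the task's workload. The VMEM estimate assumes double-buffered input and output tiles and ignores any scratch accumulators. Note that the compiler's default per-kernel VMEM limit may be lower than the 128 MiB physical capacity.

```python
PEAK_FLOPS = 918e12        # bf16 FLOP/s (spec above)
HBM_BW = 1526e9            # bytes/s, reading "GB/s" as 1e9 bytes per second
VMEM_BYTES = 128 * 2**20   # 128 MiB

# Ridge point: arithmetic intensity at which compute time equals HBM time.
ridge = PEAK_FLOPS / HBM_BW  # ~601.6 FLOPs per byte


def roofline_time(flops, bytes_moved):
    """Lower bound on runtime: no kernel beats its compute time or its HBM time."""
    return max(flops / PEAK_FLOPS, bytes_moved / HBM_BW)


# Hypothetical workload: bf16 (M, K) @ (K, N), each operand read once, output written once.
M = K = N = 4096
flops = 2 * M * K * N
bytes_moved = 2 * (M * K + K * N + M * N)  # 2 bytes per bf16 element
intensity = flops / bytes_moved            # 4096 / 3, well above the ridge point


def vmem_bytes(tm, tk, tn, itemsize=2, buffers=2):
    # x tile (tm, tk), w tile (tk, tn), out tile (tm, tn), each double-buffered.
    return buffers * itemsize * (tm * tk + tk * tn + tm * tn)
```

With these numbers the matmul is compute-bound, so its floor is roughly 150 µs, while 512-wide bf16 tiles need only about 3 MiB of VMEM.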

## TPU/JAX/Pallas Reference links

- https://docs.jax.dev/en/latest/pallas/index.html
- https://docs.jax.dev/en/latest/pallas/tpu/index.html
- https://docs.jax.dev/en/latest/jax.experimental.pallas.tpu.html
- https://docs.cloud.google.com/tpu/docs/

## Files

- `baseline.py`: the XLA/JAX reference implementation. Rewrite its `workload()` as a Pallas kernel that produces correct results and runs faster than this baseline.
- `solution.py`: your working file. Edit this to develop your implementation.
- `eval.sh`: evaluates `solution.py` on real TPU hardware. Run `bash eval.sh` to test.
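The following sketch shows what a self-contained `solution.py` might look like. It makes a hypothetical assumption that `baseline.py` defines `workload(x, w)` as a plain matmul; the real baseline may compute something else. The kernel tiles the problem over a 3-D grid with `BlockSpec`s. K is the innermost grid axis, so the same output block stays resident while partial products accumulate into it. The 128-wide tiles are placeholder choices, and the output is float32 to keep the accumulation simple.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def matmul_kernel(x_ref, w_ref, o_ref):
    # Grid axis 2 walks K; the output index_map ignores it, so o_ref is revisited.
    @pl.when(pl.program_id(2) == 0)
    def _init():
        o_ref[...] = jnp.zeros_like(o_ref)

    o_ref[...] += jnp.dot(x_ref[...], w_ref[...],
                          preferred_element_type=jnp.float32)


def matmul(x, w, *, bm=128, bk=128, bn=128, interpret=False):
    m, k = x.shape
    _, n = w.shape
    assert m % bm == 0 and k % bk == 0 and n % bn == 0, "sketch assumes exact tiling"
    return pl.pallas_call(
        matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
        grid=(m // bm, n // bn, k // bk),
        in_specs=[
            pl.BlockSpec(block_shape=(bm, bk), index_map=lambda i, j, kk: (i, kk)),
            pl.BlockSpec(block_shape=(bk, bn), index_map=lambda i, j, kk: (kk, j)),
        ],
        out_specs=pl.BlockSpec(block_shape=(bm, bn), index_map=lambda i, j, kk: (i, j)),
        interpret=interpret,
    )(x, w)


# Hypothetical signature: must match whatever baseline.py actually defines.
def workload(x, w):
    return matmul(x, w)
```

Once this translation is correct, tile sizes and grid dimension semantics are the usual next knobs to tune.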

## Rules

1. Keep the public `workload(*inputs)` function signature unchanged.
2. `solution.py` must be a complete, self-contained Python file.
3. You have a budget of **{budget}** evaluations (`eval.sh` calls). After that, `eval.sh` will refuse to run.
4. Your goal: produce the fastest correct Pallas implementation.

## Strategy

If you do not yet have a correct Pallas implementation, your first priority is to produce a correct, straightforward translation — even if it isn't faster than the XLA baseline. Only after you have a correct Pallas kernel should you focus on optimizing it for speed.
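One way to enforce this order locally is a helper that only reports a time after the candidate matches the reference. This is a hypothetical sketch: `check_and_time` is not part of the task files, and its tolerances are assumptions that bf16 workloads may need to loosen. `eval.sh` remains the authoritative check.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def check_and_time(candidate, reference, *inputs, rtol=1e-2, atol=1e-2, iters=10):
    # 1) Correctness first: compare against the reference before timing anything.
    ref = jax.block_until_ready(reference(*inputs))
    out = jax.block_until_ready(candidate(*inputs))
    np.testing.assert_allclose(np.asarray(out, np.float32),
                               np.asarray(ref, np.float32), rtol=rtol, atol=atol)

    # 2) Then speed: jit, warm up once (compilation), and average steady-state runs.
    fn = jax.jit(candidate)
    jax.block_until_ready(fn(*inputs))
    t0 = time.perf_counter()
    for _ in range(iters):
        out = fn(*inputs)
    jax.block_until_ready(out)  # dispatch is async; wait before stopping the clock
    return (time.perf_counter() - t0) / iters
```

If both files can be imported as modules, a call such as `check_and_time(solution.workload, baseline.workload, *inputs)` raises an `AssertionError` on a mismatch instead of reporting a misleading time.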

Begin by reading `baseline.py`, then create your initial `solution.py`.
