"""Sparse Mixture of Experts (MoE) — Mixtral 8x7B with Pallas kernel."""
import jax
import jax.numpy as jnp
from functools import partial
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

# Benchmark configuration.
# 'name', 'model' and 'operator' are descriptive labels copied into the report.
# The remaining entries size the generated inputs and the top-k routing.
CONFIG = dict(
    name='mixtral_8x7b_moe',
    model='Mixtral-8x7B',
    operator='sparse_moe',
    batch=1,
    seq_len=2048,
    emb_dim=4096,            # E: feature size of the activations x
    mlp_dim=14336,           # M: hidden width of each expert MLP
    num_experts=8,           # N: number of experts
    num_experts_per_tok=2,   # K: experts selected per token by top-k
)


def create_inputs(dtype=jnp.bfloat16):
    """Build deterministic random inputs for ``workload`` from ``CONFIG``.

    A fixed PRNG seed (42) is used, so repeated calls return identical arrays.

    Args:
      dtype: dtype of every generated array.

    Returns:
      Tuple ``(x, router, gate_k, up_k, down_k)`` with shapes
      ``(B, S, E)``, ``(E, N)``, ``(N, E, M)``, ``(N, E, M)``, ``(N, M, E)``.
      All weights are standard normals scaled by 0.02; ``x`` is unscaled.
    """
    batch = CONFIG['batch']
    seq = CONFIG['seq_len']
    emb = CONFIG['emb_dim']
    mlp = CONFIG['mlp_dim']
    n_experts = CONFIG['num_experts']

    k_x, k_router, k_gate, k_up, k_down = jax.random.split(jax.random.PRNGKey(42), 5)

    def scaled_normal(rng, shape):
        # Small-magnitude weights keep activations in a reasonable range.
        return jax.random.normal(rng, shape, dtype=dtype) * 0.02

    activations = jax.random.normal(k_x, (batch, seq, emb), dtype=dtype)
    return (
        activations,
        scaled_normal(k_router, (emb, n_experts)),
        scaled_normal(k_gate, (n_experts, emb, mlp)),
        scaled_normal(k_up, (n_experts, emb, mlp)),
        scaled_normal(k_down, (n_experts, mlp, emb)),
    )


def combine_kernel(expert_out_ref, weight_ref, out_ref):
    """Pallas kernel: routing-weighted sum of per-expert outputs for one tile.

    Computes ``out[t, e] = sum_n expert_out[t, n, e] * weight[t, n]`` for the
    current (token block, feature block) tile.

    Args:
      expert_out_ref: Ref of shape (bs_block, N, e_block), one row per expert.
      weight_ref: Ref of shape (bs_block, N), one weight per token and expert.
      out_ref: Ref of shape (bs_block, e_block) receiving the combined result.
    """
    # Accumulate in float32 regardless of the input dtypes. bf16 products and
    # sums lose precision, and mixed bf16/f32 inputs would otherwise promote
    # implicitly.
    expert_out = expert_out_ref[...].astype(jnp.float32)   # (bs, N, e)
    weights = weight_ref[...].astype(jnp.float32)          # (bs, N)

    weights = weights[..., None]                            # (bs, N, 1)
    acc = jnp.sum(expert_out * weights, axis=1)             # (bs, e), float32

    # Ref stores require the value dtype to match the ref dtype exactly, so
    # cast back (e.g. float32 -> bfloat16) before writing.
    out_ref[...] = acc.astype(out_ref.dtype)


def workload(x, router_weights, expert_gate_kernels, expert_up_kernels, expert_down_kernels):
    """Top-k Mixture-of-Experts forward pass with a Pallas combine step.

    Every expert's SwiGLU MLP (``silu(x @ gate) * (x @ up) @ down``) is evaluated
    densely on every token. Top-k routing only decides the per-expert weights
    used by the Pallas kernel to combine the expert outputs. Experts that were
    not selected for a token get weight 0.

    The number of experts per token, K, is read from
    ``CONFIG['num_experts_per_tok']``.

    Args:
      x: Activations of shape (B, S, E).
      router_weights: Router projection of shape (E, N).
      expert_gate_kernels: Gate projections of shape (N, E, M).
      expert_up_kernels: Up projections of shape (N, E, M).
      expert_down_kernels: Down projections of shape (N, M, E).

    Returns:
      Array of shape (B, S, E) with dtype ``x.dtype``.
    """
    B, S, E = x.shape
    N = router_weights.shape[-1]
    K = CONFIG['num_experts_per_tok']

    # --- Routing: keep the top-K router logits per token and normalise them. ---
    logits = jnp.dot(x, router_weights)                       # (B, S, N)
    top_k_logits, top_k_indices = jax.lax.top_k(logits, K)    # (B, S, K)
    # Softmax in float32. The weights are consumed as float32 below, and a
    # bf16 softmax needlessly loses precision.
    router_probs = jax.nn.softmax(top_k_logits.astype(jnp.float32), axis=-1)

    # --- Experts: SwiGLU MLP for all N experts (dense; no token dispatch). ---
    gate_out = jax.nn.silu(jnp.einsum('bse,nem->bsnm', x, expert_gate_kernels))
    up_out = jnp.einsum('bse,nem->bsnm', x, expert_up_kernels)
    hidden = gate_out * up_out
    expert_outputs = jnp.einsum('bsnm,nme->bsne', hidden, expert_down_kernels)

    # Scatter the K probabilities into a dense per-expert weight tensor;
    # experts that were not selected receive weight 0.
    one_hot = jax.nn.one_hot(top_k_indices, N, dtype=jnp.float32)   # (B, S, K, N)
    expert_weights = (one_hot * router_probs[..., None]).sum(axis=2)  # (B, S, N)

    # --- Combine: flatten tokens and run the Pallas weighted-sum kernel. ---
    BS = B * S
    expert_outputs_2d = expert_outputs.reshape(BS, N, E)
    expert_weights_2d = expert_weights.reshape(BS, N)

    bs_block = min(128, BS)
    e_block = min(128, E)

    # Ceiling division: when BS or E is not a multiple of the block size, the
    # trailing partial tile must still be visited. Otherwise those output rows
    # or columns would never be written. Pallas drops out-of-bounds writes of
    # the last partial block.
    grid = (pl.cdiv(BS, bs_block), pl.cdiv(E, e_block))

    out = pl.pallas_call(
        combine_kernel,
        out_shape=jax.ShapeDtypeStruct((BS, E), x.dtype),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            grid=grid,
            in_specs=[
                pl.BlockSpec((bs_block, N, e_block), lambda i, j: (i, 0, j)),
                pl.BlockSpec((bs_block, N), lambda i, j: (i, 0)),
            ],
            out_specs=pl.BlockSpec((bs_block, e_block), lambda i, j: (i, j)),
        ),
    )(expert_outputs_2d, expert_weights_2d)

    return out.reshape(B, S, E)


def benchmark(num_warmup=5, num_iters=100):
    """Time the jitted ``workload`` and report latency and throughput.

    Args:
      num_warmup: Untimed calls made before measuring. With ``jax.jit``, the
        first call also compiles.
      num_iters: Number of timed calls; must be at least 1.

    Returns:
      JSON-serialisable dict with:
        - the config;
        - mean and std latency in milliseconds;
        - achieved TFLOP/s;
        - the output shape.

    Raises:
      ValueError: if ``num_iters`` < 1. With no samples, the mean would be NaN,
        and with ``num_warmup == 0`` no output would exist to report.
    """
    if num_iters < 1:
        raise ValueError(f"num_iters must be >= 1, got {num_iters}")

    import time
    import numpy as np

    inputs = create_inputs()
    fn = jax.jit(workload)
    for _ in range(num_warmup):
        fn(*inputs).block_until_ready()

    times = []
    for _ in range(num_iters):
        t0 = time.perf_counter()
        out = fn(*inputs)
        out.block_until_ready()
        times.append(time.perf_counter() - t0)
    times = np.array(times) * 1000  # seconds -> milliseconds

    B, S, E, M = CONFIG['batch'], CONFIG['seq_len'], CONFIG['emb_dim'], CONFIG['mlp_dim']
    K, N = CONFIG['num_experts_per_tok'], CONFIG['num_experts']
    # Model FLOPs: the router matmul plus three E x M matmuls (gate, up, down)
    # for each of the K routed experts per token. The workload actually
    # evaluates all N experts densely, so executed FLOPs are higher than this
    # count on the expert part.
    routing_flops = B * S * E * N * 2
    expert_flops = B * S * K * (E * M * 2 * 3)
    flops = routing_flops + expert_flops
    avg = float(np.mean(times))
    return {
        'name': CONFIG['name'],
        'model': CONFIG['model'],
        'operator': CONFIG['operator'],
        'config': {k: v for k, v in CONFIG.items() if k not in ('name', 'model', 'operator')},
        'time_ms': round(avg, 4),
        'std_ms': round(float(np.std(times)), 4),
        'tflops': round(flops / (avg / 1000) / 1e12, 2),
        'output_shape': list(out.shape),
        'status': 'success',
    }


if __name__ == '__main__':
    # Emit a single JSON line with the benchmark report.
    import json

    report = benchmark()
    print(json.dumps(report))
