import jax
from jax import pmap
import jax.numpy as np

# Demo: multiply two small matrices, then redo the product shard-by-shard
# along a named mapped axis and average the per-shard results with a collective.

# `a` is 3x4 and `b` is 2x3, so `b @ a` is 2x4.
a = np.array([[1.0] * 4, [2.0] * 4, [3.0] * 4])
b = np.array([[3.0] * 3, [5.0] * 3])
print(a)
print(a.shape)
print(b)
print(b.shape)
print(b @ a)

# Add a leading axis of size 2: each mapped instance receives one (1, 3) shard.
b = b.reshape(2, 1, 3)

# pmap needs at least one local device per entry of the mapped axis. When the
# host has fewer devices, fall back to vmap, which accepts the same axis_name
# and supports the same collectives, so the printed values are identical.
_mapper = pmap if jax.local_device_count() >= b.shape[0] else jax.vmap

# Each instance multiplies ITS OWN shard `x` by `a`. (Using the closed-over `b`
# here would make every instance recompute the full product.)
y = _mapper(lambda x: x @ a, axis_name='i')(b)
print(y)

# The axis name 'i' is only bound inside the mapped function, so the pmean
# collective must be called there; calling it on `y` at top level raises
# an unbound-axis-name error.
y_mean = _mapper(
    lambda x: jax.lax.pmean(x @ a, axis_name='i'),
    axis_name='i',
)(b)
print(y_mean)