import einx._src.tracer as tracer
from functools import partial
from .vmap import vmap

class nn:
    """Traced stand-in for a ``jax.nn``-like namespace.

    Each exposed function is a ``functools.partial`` of a numpy-style tracer
    signature with ``op`` bound to the corresponding function on the traced
    namespace that was passed in.
    """

    def __init__(self, nn):
        # Keep a handle on the traced namespace itself; it is not used further
        # inside this class.
        self._nn = nn

        numpy_sig = tracer.signature.numpy
        reduce_signature = numpy_sig.reduce
        same_shape_signature = numpy_sig.preserve_shape

        # logsumexp goes through the tracer's `reduce` signature.
        self.logsumexp = partial(reduce_signature, op=nn.logsumexp)
        # softmax / log_softmax go through the `preserve_shape` signature
        # (presumably: output shape equals input shape).
        self.softmax = partial(same_shape_signature, op=nn.softmax)
        self.log_softmax = partial(same_shape_signature, op=nn.log_softmax)

class jax:
    """Traced stand-in for the ``jax`` backend module.

    On construction, this records traced imports of ``jax`` and ``jax.numpy``
    (aliased ``jnp``). It then exposes:

    * ``numpy`` -- a numpy-style tracer signature wrapping the traced ``jnp``.
    * ``vmap``  -- the ``vmap`` wrapper around the traced ``jax.vmap``.
    * ``nn``    -- the ``nn`` wrapper around the traced ``jax.nn``.
    """

    def __init__(self):
        python_sig = tracer.signature.python
        # The import order is kept as-is, because each import_ call may be
        # recorded by the tracer.
        jax_module = python_sig.import_("jax")
        jnp_module = python_sig.import_("jax.numpy", as_="jnp")

        self.numpy = tracer.signature.numpy(jnp_module)
        self.vmap = vmap(jax_module.vmap)
        self.nn = nn(jax_module.nn)

    def at(self, x, indices, updates, *, op):
        """Trace the JAX indexed update ``x.at[indices].<op>(updates)``.

        Args:
            x: Traced array; its ``_tracer_type`` is reused for the result.
            indices: Index expression applied to ``x.at``.
            updates: Argument passed to the chosen update method.
            op: Name of the method looked up on ``x.at[indices]``
                (for example ``"set"`` or ``"add"``, per JAX's ``.at`` API).

        Returns:
            The traced update result, cast to ``x``'s tracer type.
        """
        original_type = x._tracer_type
        traced_getattr = tracer.signature.python.getattr

        at_helper = traced_getattr(x, "at")
        update_method = traced_getattr(at_helper[indices], op)
        result = update_method(updates)

        # Cast back so the result carries the same tracer type as the input.
        return tracer.cast(result, original_type)