import einx._src.tracer as tracer
from functools import partial
from .vmap import vmap
from .einsum import einsum

class nn:
    """Traced stand-in for the ``mlx.nn`` module.

    Exposes ``softmax`` and ``log_softmax``. Each is a ``functools.partial``
    of ``tracer.signature.numpy.preserve_shape`` with ``op`` bound to the
    corresponding attribute of the wrapped module.
    """

    def __init__(self, nn):
        # NOTE: the parameter shadows this class's name inside ``__init__``;
        # it is kept as ``nn`` so keyword callers keep working.
        self._nn = nn
        preserve_shape = tracer.signature.numpy.preserve_shape
        # Keep softmax before log_softmax so attribute accesses on the
        # wrapped (possibly traced) module happen in the original order.
        self.softmax = partial(preserve_shape, op=nn.softmax)
        self.log_softmax = partial(preserve_shape, op=nn.log_softmax)

class core:
    """Traced stand-in for the ``mlx.core`` module.

    On construction it copies onto the instance:

    * the numpy-style operations produced by ``tracer.signature.numpy``,
    * ``vmap``/``einsum`` wrappers around the module's own functions,
    * ``array`` and a handful of dtypes.

    ``setitem`` and ``at`` express in-place style updates through the
    tracer and cast the result back to the input's tracer type.
    """

    def __init__(self, mx):
        self._mx = mx

        numpy_signature = tracer.signature.numpy
        as_numpy = numpy_signature(self._mx)
        extra_op_names = [
            "roll",
            "sort",
            "argsort",
            "reshape",
            "transpose",
            "split",
            "broadcast_to",
            "divmod",
            "take",
            "getitem",
            "arange",
            "concatenate",
            "dot",
        ]
        forwarded_names = (
            numpy_signature.elementwise_op_names
            + numpy_signature.reduce_op_names
            + extra_op_names
        )
        for op_name in forwarded_names:
            setattr(self, op_name, getattr(as_numpy, op_name))

        self.vmap = vmap(self._mx.vmap)
        self.einsum = einsum(self._mx.einsum)
        self.array = self._mx.array

        for dtype_name in ["int32", "int64", "float32", "float64"]:
            setattr(self, dtype_name, getattr(self._mx, dtype_name))
        # The wrapped module exposes the boolean dtype as ``bool_``; it is
        # published here under the plain name ``bool``.
        self.bool = self._mx.bool_

    def setitem(self, x, indices, values):
        """Return ``x`` with ``x[indices] = values`` applied via the tracer.

        The result of ``tracer.signature.python.setitem`` is cast back to
        ``x``'s original tracer type.
        """
        original_type = x._tracer_type
        updated = tracer.signature.python.setitem(x, indices, values)
        return tracer.cast(updated, original_type)

    def at(self, x, indices, updates, *, op):
        """Return the traced result of ``x.at[indices].<op>(updates)``.

        ``op`` is looked up as a method name on the ``at[...]`` view
        (presumably one of MLX's update methods such as ``"add"`` — confirm
        against callers). The result is cast back to ``x``'s original
        tracer type.
        """
        original_type = x._tracer_type
        py_signature = tracer.signature.python
        at_view = py_signature.getattr(x, "at")[indices]
        update_method = py_signature.getattr(at_view, op)
        return tracer.cast(update_method(updates), original_type)

class mlx:
    """Traced MLX backend namespace.

    Imports ``mlx.core`` (as ``mx``) and ``mlx.nn`` (as ``mnn``) through
    ``tracer.signature.python.import_``. It wraps them in ``core`` and
    ``nn`` respectively.
    """

    def __init__(self):
        # Order matters for tracing: mlx.core is imported and wrapped before
        # mlx.nn is imported.
        traced_core = tracer.signature.python.import_("mlx.core", as_="mx")
        self.core = core(traced_core)
        traced_nn = tracer.signature.python.import_("mlx.nn", as_="mnn")
        self.nn = nn(traced_nn)
