import einx._src.tracer as tracer
import einx._src.adapter as adapter
from ..api import api
from ..types import Tensor
from ..backend import registry
from ..backend import Backend

def _get_backend_kwargs():
    """Collect the extra keyword arguments that ``create`` forwards to ``Backend``.

    Returns:
        A dict with three entries:

        * ``"optimizations"``: a list of tracer optimizer instances; the
          ``classical`` ones are each constructed from the matching attribute
          of the tracer-side ``mlx.core`` handle.
        * ``"compiler"``: ``tracer.compiler.python``.
        * ``"tensor_type"``: a ``(type, shape_getter)`` pair made of the real
          ``mlx.core.array`` class and a callable returning an object's
          ``.shape`` as a tuple of Python ``int`` values.
    """
    # Tracer-side handle for ``mlx.core`` (named ``mx`` via ``as_``); this is
    # distinct from the real module imported further below.
    traced_mx = tracer.signature.python.import_("mlx.core", as_="mx")
    classical_optimizers = tracer.optimizer.classical

    optimizations = [
        classical_optimizers.SkipReshape(traced_mx.reshape),
        classical_optimizers.SkipTranspose(traced_mx.transpose),
        classical_optimizers.SkipBroadcastTo(traced_mx.broadcast_to),
        classical_optimizers.SkipConcatenate(traced_mx.concatenate),
        tracer.optimizer.InlineGraph(),
        tracer.optimizer.SkipCast(),
    ]

    # The real module is imported inside the function body, after the
    # optimizer list has been built.
    import mlx.core as mlx_core

    tensor_type = (mlx_core.array, lambda x: tuple(int(dim) for dim in x.shape))

    return {
        "optimizations": optimizations,
        "compiler": tracer.compiler.python,
        "tensor_type": tensor_type,
    }

def create():
    """Assemble the ``mlx.classical`` backend.

    The op table is built in stages through the ``adapter`` helpers: the
    classical and einsum interfaces are derived from the MLX tracer signature,
    one decomposed-named-tensor op is created per name in each ``adapter.ops``
    group, and the resulting table is passed through the named-tensor and einx
    adapter layers.

    Returns:
        A ``(backend, tensor_types)`` tuple: the constructed ``Backend`` and a
        one-element list containing ``mlx.core.array``.
    """
    traced_mlx = tracer.signature.mlx()

    classical = adapter.classical_from_mlx(traced_mlx)
    raw_einsum = adapter.einsum_from_mlx(traced_mlx)
    einsum = adapter.einsum_detachscalars(raw_einsum, classical)

    from_classical = adapter.decomposednamedtensor_from_classical

    # Fill the table group by group. Insertion order is kept, and a name that
    # occurs in a later group replaces the earlier entry.
    decomposed_ops = {}
    for op_name in adapter.ops.elementwise:
        decomposed_ops[op_name] = from_classical.elementwise(getattr(classical, op_name), classical)
    for op_name in adapter.ops.reduce:
        decomposed_ops[op_name] = from_classical.reduce(getattr(classical, op_name))
    for op_name in adapter.ops.preserve_shape:
        decomposed_ops[op_name] = from_classical.preserve_shape(getattr(classical, op_name))
    for op_name in adapter.ops.argfind:
        decomposed_ops[op_name] = from_classical.argfind(getattr(classical, op_name), classical)
    for op_name in adapter.ops.update_at:
        decomposed_ops[op_name] = from_classical.update_at_ravelled(getattr(classical, op_name), classical)
    decomposed_ops["get_at"] = from_classical.get_at_ravelled(classical)
    decomposed_ops["dot"] = adapter.decomposednamedtensor_from_einsum.dot(einsum)

    named_ops = adapter.namedtensor_from_decomposednamedtensor.ops(decomposed_ops, classical)
    named_ops = adapter.namedtensor_calltensorfactory.ops(named_ops, expected_type=traced_mlx.core.array)
    einx_ops = adapter.einx_from_namedtensor.ops(named_ops)

    backend = Backend(
        ops=einx_ops,
        name="mlx.classical",
        priority=0,
        **_get_backend_kwargs(),
    )

    # Real MLX module, needed for the concrete array type in the return value.
    import mlx.core as mlx_core

    return backend, [mlx_core.array]

# Register the `create` factory with the backend registry under the name "mlx".
# NOTE(review): presumably `create` is only invoked once `mlx` has been
# imported, so this module does not import MLX eagerly; confirm in `registry`.
registry.register_on_import("mlx", create)
