from ..backend import registry
from ..backend import Backend
import einx._src.adapter as adapter
import einx._src.tracer as tracer
import numpy as np

# Tracer-side handle for NumPy, produced by tracer.signature.numpy(). It is the
# input to the adapters built below; presumably it is a traced stand-in rather
# than the real module (which is imported as `np`) -- confirm in tracer.signature.
numpy = tracer.signature.numpy()

# Adapters derived from the NumPy handle: operations are looked up by name on
# `classical` (via getattr) further down, and `einsum` backs the "dot" op.
classical = adapter.classical_from_numpy(numpy)
einsum = adapter.einsum_from_numpy(numpy)

# Table mapping each operation name to its decomposed-named-tensor implementation.
# For every op family listed in adapter.ops, the attribute of the same name on
# `classical` is wrapped by the matching decomposednamedtensor_from_classical
# helper. Entries are merged in the order written, so a later entry wins if a
# name occurs more than once.
decomposednamedtensor_ops = {
    # Elementwise ops (the wrapper also receives the classical adapter itself).
    **{
        op_name: adapter.decomposednamedtensor_from_classical.elementwise(
            getattr(classical, op_name), classical
        )
        for op_name in adapter.ops.elementwise
    },
    # Reduction ops.
    **{
        op_name: adapter.decomposednamedtensor_from_classical.reduce(
            getattr(classical, op_name)
        )
        for op_name in adapter.ops.reduce
    },
    # Ops in the preserve_shape family.
    **{
        op_name: adapter.decomposednamedtensor_from_classical.preserve_shape(
            getattr(classical, op_name)
        )
        for op_name in adapter.ops.preserve_shape
    },
    # Ops in the argfind family.
    **{
        op_name: adapter.decomposednamedtensor_from_classical.argfind(
            getattr(classical, op_name), classical
        )
        for op_name in adapter.ops.argfind
    },
    # Ops in the update_at family, wrapped with the "ravelled" helper.
    **{
        op_name: adapter.decomposednamedtensor_from_classical.update_at_ravelled(
            getattr(classical, op_name), classical
        )
        for op_name in adapter.ops.update_at
    },
    # Two ops that are not generated from a family list.
    "get_at": adapter.decomposednamedtensor_from_classical.get_at_ravelled(classical),
    "dot": adapter.decomposednamedtensor_from_einsum.dot(einsum),
}

# Lift the op table through the remaining adapter layers; each call takes the
# previous table and returns a new one:
#   decomposed-named-tensor ops -> named-tensor ops
#   -> named-tensor ops passed through namedtensor_calltensorfactory
#      (with numpy.ndarray given as the expected type)
#   -> einx-level ops.
# NOTE: `numpy` here is still the tracer.signature.numpy() handle; the name is
# rebound later in this file, so these lines must stay above that rebinding.
namedtensor_ops = adapter.namedtensor_from_decomposednamedtensor.ops(decomposednamedtensor_ops, classical)
namedtensor_ops = adapter.namedtensor_calltensorfactory.ops(namedtensor_ops, expected_type=numpy.ndarray)
einx_ops = adapter.einx_from_namedtensor.ops(namedtensor_ops)

# Rebinds the module-level name `numpy`: from here on it refers to the handle
# returned by tracer.signature.python.import_("numpy", as_="np"), not to the
# tracer.signature.numpy() handle that earlier statements in this file used.
numpy = tracer.signature.python.import_("numpy", as_="np")
# Optimizer passes handed to the backend, in this order. The classical Skip*
# passes are each constructed with the corresponding function handle taken from
# `numpy`; what exactly each pass elides is defined in tracer.optimizer
# (presumably redundant reshape/transpose/broadcast_to/concatenate calls and
# casts -- confirm there).
optimizations = [
    tracer.optimizer.classical.SkipReshape(numpy.reshape),
    tracer.optimizer.classical.SkipTranspose(numpy.transpose),
    tracer.optimizer.classical.SkipBroadcastTo(numpy.broadcast_to),
    tracer.optimizer.classical.SkipConcatenate(numpy.concatenate),
    tracer.optimizer.InlineGraph(),
    tracer.optimizer.SkipCast(),
]

# Backend description for NumPy: bundles the einx-level op table, the optimizer
# passes and the tracer's Python compiler under the name "numpy".
backend = Backend(
    ops=einx_ops,
    name="numpy",
    priority=-1,
    optimizations=optimizations,
    compiler=tracer.compiler.python,
    # Pair of (tensor class, function returning the shape as a tuple of Python ints).
    tensor_type=(np.ndarray, lambda x: tuple(int(dim) for dim in x.shape)),
)

# Register the backend for NumPy arrays as well as for plain Python containers
# and scalars and their NumPy scalar counterparts.
registry.register(
    backend,
    [
        np.ndarray,
        list,
        tuple,
        int,
        float,
        bool,
        np.integer,
        np.floating,
        np.bool_,
    ],
)