import einx._src.tracer as tracer
import einx._src.adapter as adapter
from ..api import api
from ..types import Tensor
import types
import inspect
import functools
from functools import partial
from ..backend import registry
from ..backend import Backend
from ..backend import OperationNotSupportedException
from ._util import _get_kwargnames
from ._util import _adapt_parameters

def _get_backend_kwargs():
    """Collect the keyword arguments shared by the JAX backends in this module.

    Returns:
        A dict with three entries:

        * ``"optimizations"``: tracer optimizations. These are the classical
          Skip* passes keyed on the traced ``jax.numpy`` reshape / transpose /
          broadcast_to / concatenate functions, plus ``InlineGraph`` and
          ``SkipCast``.
        * ``"compiler"``: ``tracer.compiler.python``.
        * ``"tensor_type"``: a ``(jax.numpy.ndarray, shape_getter)`` pair.
          ``shape_getter`` converts an array's shape into a tuple of Python ints.
    """
    # Traced handle to ``jax.numpy`` (imported as ``jnp`` in the generated code).
    # It gets its own name so it is not confused with the real module imported below.
    traced_jnp = tracer.signature.python.import_("jax.numpy", as_="jnp")
    optimizations = [
        tracer.optimizer.classical.SkipReshape(traced_jnp.reshape),
        tracer.optimizer.classical.SkipTranspose(traced_jnp.transpose),
        tracer.optimizer.classical.SkipBroadcastTo(traced_jnp.broadcast_to),
        tracer.optimizer.classical.SkipConcatenate(traced_jnp.concatenate),
        tracer.optimizer.InlineGraph(),
        tracer.optimizer.SkipCast(),
    ]

    # The real module. It is needed for the concrete array type used in type checks.
    import jax.numpy as jnp

    def shape_of(array):
        return tuple(int(dim) for dim in array.shape)

    return {
        "optimizations": optimizations,
        "compiler": tracer.compiler.python,
        "tensor_type": (jnp.ndarray, shape_of),
    }

def _adapter(adapt):
    """Turn an ``adapt(op, parameters)`` builder into a user-facing adapter.

    The returned ``inner`` can be called directly as ``inner(op, ...)`` or used
    as a decorator factory, ``@inner(static_argnames=...)``. In the second form
    it is called with ``op=None`` and returns a ``functools.partial`` that waits
    for the function.

    Args of ``inner``:
        op: Function to adapt. Its signature must not contain a
            ``description`` parameter.
        use_type_annotations: Forwarded to ``api``.
        static_argnames: Extra argument names to treat as static. They are
            added after ``"parameters"`` and ``"description"``. Either a single
            string or an iterable of strings.

    Raises:
        ValueError: if ``op`` declares a ``description`` parameter.
    """
    @functools.wraps(adapt)
    def inner(op=None, use_type_annotations=False, static_argnames=None):
        if op is None:
            # Decorator-factory form: defer until the function is supplied.
            return partial(inner, use_type_annotations=use_type_annotations, static_argnames=static_argnames)

        signature = inspect.signature(op)
        if "description" in signature.parameters:
            raise ValueError("The adapted function must not have a 'description' parameter.")

        # Pass a fresh mutable copy, because some builders delete entries
        # (e.g. ``axis``) from it.
        op, parameters = adapt(op, parameters={**signature.parameters})

        if static_argnames is None:
            extra_static_argnames = []
        elif isinstance(static_argnames, str):
            # A bare string names one argument. ``list("abc")`` would split it
            # into characters.
            extra_static_argnames = [static_argnames]
        else:
            extra_static_argnames = list(static_argnames)

        return api(
            op,
            backend=types.SimpleNamespace(**_get_backend_kwargs()),
            signature=inspect.Signature(parameters=_adapt_parameters(parameters).values()),
            static_argnames=["parameters", "description"] + extra_static_argnames,
            use_type_annotations=use_type_annotations,
        )
    return inner

@_adapter
def adapt_with_vmap(op, parameters):
    """Wrap ``op`` in the vmap-based adapter chain for JAX.

    The wrappers are applied in this order:

    1. a tracer constant around ``op``;
    2. ``decomposednamedtensor_from_vmap`` (using the JAX vmap adapter);
    3. ``namedtensor_from_decomposednamedtensor`` (using the classical JAX ops);
    4. ``namedtensor_calltensorfactory``;
    5. ``einx_from_namedtensor``, given the keyword names from ``parameters``.

    Returns the wrapped op and ``parameters`` unchanged.
    """
    jax_sig = tracer.signature.jax()
    classical_ops = adapter.classical_from_jax(jax_sig)
    vmap_impl = adapter.vmap_from_jax(jax_sig)

    wrapped = tracer.signature.python.constant(op)
    wrapped = adapter.decomposednamedtensor_from_vmap.op(
        wrapped, vmap_impl, expected_type=jax_sig.numpy.ndarray
    )
    wrapped = adapter.namedtensor_from_decomposednamedtensor.op(wrapped, classical_ops)
    wrapped = adapter.namedtensor_calltensorfactory.op(
        wrapped, expected_type=jax_sig.numpy.ndarray
    )
    kwargnames = _get_kwargnames(parameters)
    wrapped = adapter.einx_from_namedtensor.op(wrapped, kwargnames=kwargnames)

    return wrapped, parameters

@_adapter
def adapt_classical_reduce(op, parameters):
    """Wrap a reduction ``op`` that takes an ``axis`` argument for einx.

    The ``axis`` entry is removed from ``parameters``, so it is not part of the
    adapted signature. Presumably the reduce adapter supplies it; confirm
    against ``decomposednamedtensor_from_classical.reduce``. ``op`` is then
    wrapped by the classical reduce adapter, the named-tensor adapter, the
    tensor-factory call adapter, and ``einx_from_namedtensor.reduce``.

    Returns the wrapped op and the reduced ``parameters``.

    Raises:
        ValueError: if ``parameters`` has no ``axis`` entry.
    """
    if "axis" in parameters:
        del parameters["axis"]
    else:
        raise ValueError("The adapted function must have an 'axis' parameter, but none was found in its signature.")

    jax_sig = tracer.signature.jax()
    classical_ops = adapter.classical_from_jax(jax_sig)

    wrapped = tracer.signature.python.constant(op)
    wrapped = adapter.decomposednamedtensor_from_classical.reduce(
        wrapped, expected_type=jax_sig.numpy.ndarray
    )
    wrapped = adapter.namedtensor_from_decomposednamedtensor.op(wrapped, classical_ops)
    wrapped = adapter.namedtensor_calltensorfactory.op(
        wrapped, expected_type=jax_sig.numpy.ndarray
    )
    kwargnames = _get_kwargnames(parameters)
    wrapped = adapter.einx_from_namedtensor.reduce(wrapped, kwargnames=kwargnames)

    return wrapped, parameters

@_adapter
def adapt_classical_elementwise(op, parameters):
    """Wrap an elementwise ``op`` for einx using the classical JAX ops.

    ``op`` is wrapped by the classical elementwise adapter, the named-tensor
    adapter, the tensor-factory call adapter, and
    ``einx_from_namedtensor.elementwise``, in that order.

    Returns the wrapped op and ``parameters`` unchanged.
    """
    jax_sig = tracer.signature.jax()
    classical_ops = adapter.classical_from_jax(jax_sig)

    wrapped = tracer.signature.python.constant(op)
    wrapped = adapter.decomposednamedtensor_from_classical.elementwise(
        wrapped, classical_ops, expected_type=jax_sig.numpy.ndarray
    )
    wrapped = adapter.namedtensor_from_decomposednamedtensor.op(wrapped, classical_ops)
    wrapped = adapter.namedtensor_calltensorfactory.op(
        wrapped, expected_type=jax_sig.numpy.ndarray
    )
    kwargnames = _get_kwargnames(parameters)
    wrapped = adapter.einx_from_namedtensor.elementwise(wrapped, kwargnames=kwargnames)

    return wrapped, parameters



def create():
    """Build the ``jax.classical`` einx backend.

    Every op group in ``adapter.ops`` is mapped through its matching
    ``decomposednamedtensor_from_classical`` adapter. ``get_at`` and the
    einsum-based ``dot`` are added, and the result is lifted through the
    named-tensor, tensor-factory and einx layers.

    Returns:
        ``(backend, [jax.numpy.ndarray])``. The list presumably names the tensor
        types this backend handles; confirm against ``registry``.
    """
    jax_sig = tracer.signature.jax()
    classical_ops = adapter.classical_from_jax(jax_sig)
    einsum_impl = adapter.einsum_from_jax(jax_sig)
    from_classical = adapter.decomposednamedtensor_from_classical

    elementwise_ops = {
        name: from_classical.elementwise(getattr(classical_ops, name), classical_ops)
        for name in adapter.ops.elementwise
    }
    reduce_ops = {
        name: from_classical.reduce(getattr(classical_ops, name))
        for name in adapter.ops.reduce
    }
    preserve_shape_ops = {
        name: from_classical.preserve_shape(getattr(classical_ops, name))
        for name in adapter.ops.preserve_shape
    }
    argfind_ops = {
        name: from_classical.argfind(getattr(classical_ops, name), classical_ops)
        for name in adapter.ops.argfind
    }
    update_at_ops = {
        name: from_classical.update_at_ravelled(getattr(classical_ops, name), classical_ops)
        for name in adapter.ops.update_at
    }
    special_ops = {
        "get_at": from_classical.get_at_ravelled(classical_ops),
        "dot": adapter.decomposednamedtensor_from_einsum.dot(einsum_impl),
    }

    # If a name appears in several groups, the later group wins, as with a
    # left-to-right dict union.
    decomposednamedtensor_ops = {
        **elementwise_ops,
        **reduce_ops,
        **preserve_shape_ops,
        **argfind_ops,
        **update_at_ops,
        **special_ops,
    }

    named_ops = adapter.namedtensor_from_decomposednamedtensor.ops(decomposednamedtensor_ops, classical_ops)
    named_ops = adapter.namedtensor_calltensorfactory.ops(named_ops, expected_type=jax_sig.numpy.ndarray)

    backend = Backend(
        ops=adapter.einx_from_namedtensor.ops(named_ops),
        name="jax.classical",
        priority=0,
        **_get_backend_kwargs(),
    )

    import jax.numpy as jnp
    return backend, [jnp.ndarray]

# Registered immediately. This module later rebinds the name ``create`` for the
# jax.vmap backend, but that does not affect the callback passed here.
registry.register_on_import("jax", create)



def create():
    """Build the ``jax.vmap`` einx backend.

    Every op in ``adapter.ops.all`` is taken from the elementary ops derived
    from the classical JAX ops and wrapped by the vmap-based adapter. All
    ``update_at`` ops are replaced by a stub that raises
    ``OperationNotSupportedException``.

    Returns:
        ``(backend, [])``. No tensor types are listed for this backend.
    """
    jax_sig = tracer.signature.jax()
    classical_ops = adapter.classical_from_jax(jax_sig)
    vmap_impl = adapter.vmap_from_jax(jax_sig)

    def update_at(*args, **kwargs):
        raise OperationNotSupportedException("update_at operations are not supported by the jax.vmap backend.")

    elementary_ops = adapter.elementary_from_classical.ops(classical_ops)

    vmapped_ops = {
        name: adapter.decomposednamedtensor_from_vmap.op(
            elementary_ops[name],
            vmap_impl,
            expected_type=jax_sig.numpy.ndarray,
            allow_squeeze_unsqueeze=True,
            classical=classical_ops,
        )
        for name in adapter.ops.all
    }
    # Merged last, so the unsupported stubs override any vmapped version of the
    # same name.
    decomposednamedtensor_ops = {
        **vmapped_ops,
        **dict.fromkeys(adapter.ops.update_at, update_at),
    }

    named_ops = adapter.namedtensor_from_decomposednamedtensor.ops(decomposednamedtensor_ops, classical_ops)
    named_ops = adapter.namedtensor_calltensorfactory.ops(named_ops, expected_type=jax_sig.numpy.ndarray)

    backend = Backend(
        ops=adapter.einx_from_namedtensor.ops(named_ops),
        name="jax.vmap",
        priority=0,
        **_get_backend_kwargs(),
    )

    return backend, []

registry.register_on_import("jax", create)