import einx._src.tracer as tracer
import einx._src.adapter as adapter
from ..api import api
import types
import inspect
from ..backend import registry
from ..backend import Backend
from ..backend import OperationNotSupportedException
import threading
from ._util import _get_kwargnames
from ._util import _adapt_parameters

from .torch import _raise_on_invalid_version
from .torch import _allow_ops_in_graph

def _get_backend_kwargs():
    """Assemble the keyword arguments that describe this backend.

    ``torch`` is imported inside the function body, so the import only
    happens when this function is called.

    Returns:
        A dict with three entries:

        * ``"optimizations"``: a list with one ``InlineGraph`` and one
          ``SkipCast`` optimizer instance, in that order.
        * ``"compiler"``: ``tracer.compiler.python``.
        * ``"tensor_type"``: a pair ``(torch.Tensor, shape_fn)`` where
          ``shape_fn`` returns the ``.shape`` of its argument as a tuple of
          ``int`` values.
    """
    # The optimizers are created before torch is imported.
    optimizer_passes = [tracer.optimizer.InlineGraph(), tracer.optimizer.SkipCast()]

    import torch

    backend_kwargs = {}
    backend_kwargs["optimizations"] = optimizer_passes
    backend_kwargs["compiler"] = tracer.compiler.python
    # Each entry of the shape is converted to a plain Python int.
    backend_kwargs["tensor_type"] = (torch.Tensor, lambda x: tuple(map(int, x.shape)))
    return backend_kwargs

def adapt(op):
    """Wrap ``op`` through the adapter chain and expose it via ``api``.

    The signature of ``op`` is inspected first: it must declare an ``axis``
    parameter and must not declare a ``description`` parameter. The ``axis``
    parameter is dropped from the parameter set that is forwarded to
    ``_get_kwargnames`` and ``_adapt_parameters``.

    Args:
        op: The callable to adapt. It must be introspectable with
            ``inspect.signature``.

    Returns:
        The function produced by ``api``, after it has been passed to
        ``torch.compiler.allow_in_graph``.

    Raises:
        ValueError: If ``op`` has a ``description`` parameter, or if it has
            no ``axis`` parameter.
    """
    _raise_on_invalid_version()

    op_signature = inspect.signature(op)
    if "description" in op_signature.parameters:
        raise ValueError("The adapted function must not have a 'description' parameter.")
    if "axis" not in op_signature.parameters:
        raise ValueError("The adapted function must have an 'axis' parameter, but none was found in its signature.")

    # Keep every parameter except 'axis', preserving the declaration order.
    remaining_params = {
        name: param
        for name, param in op_signature.parameters.items()
        if name != "axis"
    }

    # Tracer-side handles; these are distinct from the real torch module
    # that is imported further below.
    torch_tracer = tracer.signature.torch()
    ftdim_tracer = tracer.signature.python.import_("functorch.dim", as_="ftdim")

    # Layer the adapters one after another around the constant-wrapped op.
    wrapped = tracer.signature.python.constant(op)
    wrapped = adapter.namedtensor_from_functorchdim.op(wrapped, ftdim_tracer)
    wrapped = adapter.namedtensor_calltensorfactory.op(wrapped, expected_type=torch_tracer.Tensor)
    wrapped = adapter.einx_from_namedtensor.op(wrapped, kwargnames=_get_kwargnames(remaining_params))

    backend_namespace = types.SimpleNamespace(**_get_backend_kwargs())
    adapted_signature = inspect.Signature(parameters=_adapt_parameters(remaining_params).values())
    func = api(
        wrapped,
        backend=backend_namespace,
        signature=adapted_signature,
    )

    # Register the resulting function with torch.compiler before returning it.
    import torch
    torch.compiler.allow_in_graph(func)
    return func



# def create():
#     _raise_on_invalid_version()
#     _allow_ops_in_graph()

#     device_stack = adapter.TorchDeviceStack() # TODO: this currently isn't used by the ops

#     torch = tracer.signature.python.import_("torch")
#     functorchdim = tracer.signature.python.import_("functorch.dim", as_="ftdim")

#     def unsupported(*args, **kwargs):
#         raise OperationNotSupportedException()

#     namedtensor_ops = \
#         {
#             name: unsupported
#             for name in adapter.ops.all
#         } | {
#             name: adapter.namedtensor_from_functorchdim.elementwise(getattr(torch, name), functorchdim)
#             for name in adapter.ops.elementwise
#         } | {
#             name: adapter.namedtensor_from_functorchdim.reduce(getattr(functorchdim.Tensor, name), functorchdim)
#             for name in adapter.ops.reduce
#         }

#     namedtensor_ops = device_stack.namedtensor.ops(namedtensor_ops)
#     namedtensor_ops = adapter.namedtensor_calltensorfactory.ops(namedtensor_ops, expected_type=torch.Tensor)
#     einx_ops = adapter.einx_from_namedtensor.ops(namedtensor_ops)

#     backend = Backend(
#         ops=einx_ops,
#         name="functorchdim",
#         priority=0,
#         **_get_backend_kwargs(),
#     )

#     return backend, []

# registry.register_on_import("functorch", create)