import einx._src.tracer as tracer
import einx._src.adapter as adapter
from ..api import api
import types
import inspect
import functools
from functools import partial
from ..backend import registry
from ..backend import Backend
from ..backend import OperationNotSupportedException
import threading
from ._util import _get_kwargnames
from ._util import _adapt_parameters

def _raise_on_invalid_version():
    import torch
    version = tuple(int(i) for i in torch.__version__.split(".")[:2])
    if version < (2, 0):
        raise OperationNotSupportedException(
            "einx with PyTorch requires PyTorch version >= 2, but found "
            f"{torch.__version__}. einx functions are disabled for PyTorch."
        )

# One-time guard for ``_allow_ops_in_graph``: set to True once every einx frontend
# op has been registered with the torch compiler.
_has_allowed_in_graph = False
_has_allowed_in_graph_lock = threading.Lock()


def _allow_ops_in_graph():
    """Register every einx frontend op with ``allow_in_graph``, at most once per process.

    Each op in ``einx._src.frontend.ops.ops`` is passed to the compiler's
    ``allow_in_graph`` hook. The hook is ``torch.compiler.allow_in_graph`` where
    available, or ``torch._dynamo.allow_in_graph`` on PyTorch 2.0, which has no
    ``torch.compiler`` namespace. If registration raises, the flag stays False,
    so the next call retries.
    """
    global _has_allowed_in_graph
    if _has_allowed_in_graph:
        return
    with _has_allowed_in_graph_lock:
        # Check again under the lock: another thread may have completed the
        # registration while this one was waiting.
        if _has_allowed_in_graph:
            return
        import torch
        from einx._src.frontend.ops import ops

        try:
            allow_in_graph = torch.compiler.allow_in_graph
        except AttributeError:
            # ``torch.compiler`` was introduced in PyTorch 2.1, but the version check
            # admits 2.0, where the same hook is only available via ``torch._dynamo``.
            import torch._dynamo
            allow_in_graph = torch._dynamo.allow_in_graph

        for op in ops:
            allow_in_graph(op)
        _has_allowed_in_graph = True

def _get_backend_kwargs():
    """Build the keyword arguments shared by the PyTorch einx backends.

    Returns:
        dict with keys:
          - ``"optimizations"``: tracer graph optimizations. The classical Skip*
            passes are parameterized with the traced ``torch.reshape``,
            ``torch.permute``, ``torch.broadcast_to`` and ``torch.cat``, followed
            by ``InlineGraph`` and ``SkipCast``.
          - ``"compiler"``: ``tracer.compiler.python``.
          - ``"tensor_type"``: a pair of the concrete ``torch.Tensor`` class and a
            function that maps a tensor to its shape as a tuple of Python ints.
    """
    traced_torch = tracer.signature.python.import_("torch")
    classical_passes = tracer.optimizer.classical
    optimizations = [
        classical_passes.SkipReshape(traced_torch.reshape),
        classical_passes.SkipTranspose(traced_torch.permute),
        classical_passes.SkipBroadcastTo(traced_torch.broadcast_to),
        classical_passes.SkipConcatenate(traced_torch.cat),
        tracer.optimizer.InlineGraph(),
        tracer.optimizer.SkipCast(),
    ]

    import torch as concrete_torch

    def shape_of(tensor):
        return tuple(int(dim) for dim in tensor.shape)

    return {
        "optimizations": optimizations,
        "compiler": tracer.compiler.python,
        "tensor_type": (concrete_torch.Tensor, shape_of),
    }

def _adapter(adapt):
    """Turn an adapter pipeline into a decorator that exposes user functions as einx ops.

    ``adapt(op, parameters=...)`` receives the user function and a mutable dict copy
    of its signature parameters. It returns ``(traced_op, parameters)``, where
    ``parameters`` is the (possibly edited) parameter dict used to build the
    signature of the resulting einx function.

    The decorator can be applied bare (``@decorator``) or with keyword options
    (``@decorator(use_type_annotations=True, static_argnames=[...])``).
    """

    @functools.wraps(adapt)
    def inner(op=None, use_type_annotations=False, static_argnames=None):
        # Called with options only: return a decorator that waits for the function.
        if op is None:
            return partial(inner, use_type_annotations=use_type_annotations, static_argnames=static_argnames)

        _raise_on_invalid_version()

        signature = inspect.signature(op)
        # 'description' is reserved: it is always passed to ``api`` as a static argument name.
        if "description" in signature.parameters:
            raise ValueError("The adapted function must not have a 'description' parameter.")

        op, parameters = adapt(op, parameters={**signature.parameters})

        func = api(
            op,
            backend=types.SimpleNamespace(**_get_backend_kwargs()),
            signature=inspect.Signature(parameters=_adapt_parameters(parameters).values()),
            static_argnames=["parameters", "description"] + ([] if static_argnames is None else list(static_argnames)),
            use_type_annotations=use_type_annotations,
        )

        import torch
        try:
            allow_in_graph = torch.compiler.allow_in_graph
        except AttributeError:
            # ``torch.compiler`` only exists from PyTorch 2.1 on, but the version check
            # above admits 2.0, where the same hook lives in ``torch._dynamo``.
            import torch._dynamo
            allow_in_graph = torch._dynamo.allow_in_graph
        # Let Dynamo write calls to ``func`` directly into the graph instead of
        # introspecting them.
        allow_in_graph(func)
        return func

    return inner

@_adapter
def adapt_with_vmap(op, parameters):
    """Adapt ``op`` into an einx operation using the vmap adapter.

    ``op`` is embedded in the trace as a constant and wrapped, in order, by:
    vmap-based decomposition, lifting to named tensors (via the classical torch
    adapter), device tracking, tensor-factory calling, and finally the einx
    calling convention. ``parameters`` is returned unchanged.
    """
    devices = adapter.TorchDeviceStack()

    traced_torch = adapter.torchautocast_from_torch(tracer.signature.torch(), devices.get_device)
    classical = adapter.classical_from_torch(traced_torch)
    vmap = adapter.vmap_from_torch(traced_torch)

    constant = tracer.signature.python.constant(op)
    decomposed = adapter.decomposednamedtensor_from_vmap.op(constant, vmap, expected_type=traced_torch.Tensor)
    named = adapter.namedtensor_from_decomposednamedtensor.op(decomposed, classical)
    named = devices.namedtensor.op(named)
    named = adapter.namedtensor_calltensorfactory.op(named, expected_type=traced_torch.Tensor)
    einx_op = adapter.einx_from_namedtensor.op(named, kwargnames=_get_kwargnames(parameters))

    return einx_op, parameters

@_adapter
def adapt_classical_reduce(op, parameters):
    """Adapt a classical reduction ``op`` (one that takes an ``axis`` argument) into an einx op.

    The ``axis`` parameter is removed from ``parameters``, so it does not appear in
    the signature of the resulting einx function. Presumably the reduce adapter
    supplies it when calling ``op`` — TODO confirm.

    Raises:
        ValueError: if ``op`` has no ``axis`` parameter.
    """
    if "axis" not in parameters:
        raise ValueError("The adapted function must have an 'axis' parameter, but none was found in its signature.")
    parameters.pop("axis")

    devices = adapter.TorchDeviceStack()

    traced_torch = adapter.torchautocast_from_torch(tracer.signature.torch(), devices.get_device)
    classical = adapter.classical_from_torch(traced_torch)

    constant = tracer.signature.python.constant(op)
    decomposed = adapter.decomposednamedtensor_from_classical.reduce(constant, expected_type=traced_torch.Tensor)
    named = adapter.namedtensor_from_decomposednamedtensor.op(decomposed, classical)
    named = devices.namedtensor.op(named)
    named = adapter.namedtensor_calltensorfactory.op(named, expected_type=traced_torch.Tensor)
    einx_op = adapter.einx_from_namedtensor.reduce(named, kwargnames=_get_kwargnames(parameters))

    return einx_op, parameters

@_adapter
def adapt_classical_elementwise(op, parameters):
    """Adapt a classical elementwise ``op`` into an einx operation.

    ``op`` is embedded in the trace as a constant and wrapped, in order, by the
    classical elementwise decomposition, lifting to named tensors, device tracking,
    tensor-factory calling, and the einx elementwise calling convention.
    ``parameters`` is returned unchanged.
    """
    devices = adapter.TorchDeviceStack()

    traced_torch = adapter.torchautocast_from_torch(tracer.signature.torch(), devices.get_device)
    classical = adapter.classical_from_torch(traced_torch)

    constant = tracer.signature.python.constant(op)
    decomposed = adapter.decomposednamedtensor_from_classical.elementwise(constant, classical, expected_type=traced_torch.Tensor)
    named = adapter.namedtensor_from_decomposednamedtensor.op(decomposed, classical)
    named = devices.namedtensor.op(named)
    named = adapter.namedtensor_calltensorfactory.op(named, expected_type=traced_torch.Tensor)
    einx_op = adapter.einx_from_namedtensor.elementwise(named, kwargnames=_get_kwargnames(parameters))

    return einx_op, parameters



def create_classical():
    """Create the ``torch.classical`` einx backend.

    Operations are built from classical (non-vmapped) torch calls:
      - elementwise, reduce, preserve-shape, argfind and update_at ops go through
        the classical decomposition adapters;
      - ``get_at`` uses ravelled indexing;
      - ``dot`` uses einsum.

    Returns:
        ``(backend, [torch.Tensor])``: the backend and the list of tensor types it
        is registered with (presumably used by the registry for dispatch — TODO confirm).

    Raises:
        OperationNotSupportedException: if the installed PyTorch is older than 2.0.
    """
    # Named ``create_classical`` (previously a second ``create``) so that the vmap
    # factory below does not silently rebind this function's name.
    _raise_on_invalid_version()
    _allow_ops_in_graph()

    device_stack = adapter.TorchDeviceStack()

    traced_torch = tracer.signature.torch()
    traced_torch = adapter.torchautocast_from_torch(traced_torch, device_stack.get_device)

    classical = adapter.classical_from_torch(traced_torch)
    einsum = adapter.einsum_from_torch(traced_torch)

    decomposednamedtensor_ops = \
        {
            name: adapter.decomposednamedtensor_from_classical.elementwise(getattr(classical, name), classical)
            for name in adapter.ops.elementwise
        } | {
            name: adapter.decomposednamedtensor_from_classical.reduce(getattr(classical, name))
            for name in adapter.ops.reduce
        } | {
            name: adapter.decomposednamedtensor_from_classical.preserve_shape(getattr(classical, name))
            for name in adapter.ops.preserve_shape
        } | {
            name: adapter.decomposednamedtensor_from_classical.argfind(getattr(classical, name), classical)
            for name in adapter.ops.argfind
        } | {
            name: adapter.decomposednamedtensor_from_classical.update_at_ravelled(getattr(classical, name), classical)
            for name in adapter.ops.update_at
        } | {
            "get_at": adapter.decomposednamedtensor_from_classical.get_at_ravelled(classical),
            "dot": adapter.decomposednamedtensor_from_einsum.dot(einsum),
        }

    namedtensor_ops = adapter.namedtensor_from_decomposednamedtensor.ops(decomposednamedtensor_ops, classical)
    namedtensor_ops = device_stack.namedtensor.ops(namedtensor_ops)
    namedtensor_ops = adapter.namedtensor_calltensorfactory.ops(namedtensor_ops, expected_type=traced_torch.Tensor)
    einx_ops = adapter.einx_from_namedtensor.ops(namedtensor_ops)

    backend = Backend(
        ops=einx_ops,
        name="torch.classical",
        priority=0,
        **_get_backend_kwargs(),
    )

    # The concrete torch module (not the traced signature) supplies the real tensor type.
    import torch
    return backend, [torch.Tensor]

registry.register_on_import("torch", create_classical)



def create():
    """Create the ``torch.vmap`` einx backend.

    Each einx operation is built from the elementary op of the same name (derived
    from the classical torch adapter) and wrapped by the vmap adapter. ``get_at``
    and every ``update_at`` op are replaced by stubs that raise
    ``OperationNotSupportedException`` when called.

    Returns:
        ``(backend, [])``: the backend and an empty list of tensor types.
    """
    _raise_on_invalid_version()
    _allow_ops_in_graph()

    devices = adapter.TorchDeviceStack()

    traced_torch = tracer.signature.torch()
    traced_torch = adapter.torchautocast_from_torch(traced_torch, devices.get_device)

    classical = adapter.classical_from_torch(traced_torch)
    vmap = adapter.vmap_from_torch(traced_torch)

    def get_at(*args, **kwargs):
        torch_error = "vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report."
        message = (
            "get_at is not supported by the torch.vmap backend. As of testing this (with PyTorch 2.7.0), "
            f"torch.vmap is not compatible with scalar indexing operations and raises the following error:\n\"{torch_error}\"\n"
            "Please use another PyTorch backend for this operation."  # TODO: e.g. torch.classical or torch.nameddims
        )
        raise OperationNotSupportedException(message)

    def update_at(*args, **kwargs):
        raise OperationNotSupportedException("update_at operations are not supported by the torch.vmap backend.")

    elementary_ops = adapter.elementary_from_classical.ops(classical)

    decomposednamedtensor_ops = {
        name: adapter.decomposednamedtensor_from_vmap.op(
            elementary_ops[name],
            vmap,
            expected_type=traced_torch.Tensor,
            allow_squeeze_unsqueeze=True,
            classical=classical,
        )
        for name in adapter.ops.all
    }
    # Replace the generic vmap wrappers for the indexing ops with the raising stubs above.
    decomposednamedtensor_ops["get_at"] = get_at
    decomposednamedtensor_ops.update(dict.fromkeys(adapter.ops.update_at, update_at))

    namedtensor_ops = adapter.namedtensor_from_decomposednamedtensor.ops(decomposednamedtensor_ops, classical)
    namedtensor_ops = devices.namedtensor.ops(namedtensor_ops)
    namedtensor_ops = adapter.namedtensor_calltensorfactory.ops(namedtensor_ops, expected_type=traced_torch.Tensor)
    einx_ops = adapter.einx_from_namedtensor.ops(namedtensor_ops)

    backend = Backend(ops=einx_ops, name="torch.vmap", priority=0, **_get_backend_kwargs())
    return backend, []

registry.register_on_import("torch", create)
