import einx._src.tracer as tracer
import numpy as np
import threading
from functools import partial
import functools
import types
from einx._src.util.functools import use_name_of

def _elementwise_to_tensor(self, op, allow_scalars=None):
    """Wrap an elementwise ``op`` so that its positional inputs are converted to tensors as required.

    Args:
        self: Backend object providing ``_to_tensor``.
        op: Callable invoked with the (possibly converted) positional arguments.
        allow_scalars: Conversion policy for arguments that are still scalars after
            all non-scalar convertible inputs have been converted:

            * ``"none"``: every argument is passed through ``self._to_tensor``.
            * ``"not-all"``: if no argument is a classical ``Tensor``, only the first
              argument is converted.
            * otherwise: a non-empty collection of positional indices at which a
              scalar may stay a scalar; convertible inputs at any other position
              are converted.

    Returns:
        The wrapping function, decorated with ``use_name_of(op)``.
    """
    scalar_types = (int, float, bool, np.integer, np.floating, np.bool_)

    @use_name_of(op)
    def wrapper(*xs):
        classical = tracer.signature.classical

        # First pass: anything convertible that is not a plain scalar always becomes a tensor.
        args = []
        for x in xs:
            is_nonscalar = isinstance(x, classical.ConvertibleTensor) and not issubclass(
                x.concrete.type, scalar_types
            )
            args.append(self._to_tensor(x) if is_nonscalar else x)

        # Second pass: apply the scalar policy to what is left.
        if allow_scalars == "none":
            args = [self._to_tensor(arg) for arg in args]
        elif allow_scalars == "not-all":
            only_scalars = not any(isinstance(arg, classical.Tensor) for arg in args)
            if only_scalars:
                # The op needs at least one tensor -> promote the first argument.
                args[0] = self._to_tensor(args[0])
        else:
            assert len(allow_scalars) > 0
            for position, arg in enumerate(args):
                if position not in allow_scalars and isinstance(arg, classical.ConvertibleTensor):
                    # A scalar at a position where scalars are not permitted.
                    args[position] = self._to_tensor(arg)

        return op(*args)

    return wrapper

class torchautocast_from_torch:
    """Facade over a torch-like namespace that converts tensor-like inputs to tensors first.

    Attributes that are not defined on this class are forwarded to the wrapped
    ``torch`` object. Reduction and elementwise operations are attached as instance
    attributes in ``__init__``; the remaining operations are regular methods that
    pass their tensor arguments through ``_to_tensor`` before delegating.
    """

    def __init__(self, torch, _get_device):
        """Store the wrapped namespace and install the reduce/elementwise wrappers.

        Args:
            torch: The torch-like namespace that all calls are delegated to.
            _get_device: Zero-argument callable returning the device that newly
                created tensors are placed on.
        """
        self._torch = torch
        self._get_device = _get_device

        valid_tensor_types = [tracer.signature.classical.Tensor]
        # Only usable in isinstance() checks if it is an actual class.
        if isinstance(self._torch.Tensor, type):
            valid_tensor_types.append(self._torch.Tensor)
        self.valid_tensor_types = tuple(valid_tensor_types)

        for name in tracer.signature.torch.reduce_op_names:
            setattr(self, name, partial(self._reduce, op=name))

        # Scalar policy per elementwise op (see _elementwise_to_tensor). The wrapped
        # function is always looked up under the very same name, so a key can never
        # be bound to a different operation by accident.
        elementwise_allow_scalars = {
            "add": "not-all",
            "subtract": "not-all",
            "multiply": "not-all",
            "true_divide": "not-all",
            "floor_divide": "not-all",
            "divide": "not-all",
            "remainder": "not-all",

            "logaddexp": "none",

            "logical_and": "none",
            "logical_or": "none",
            "logical_xor": "none",
            "where": [1, 2],
            "maximum": "none",
            "minimum": "none",

            "less": [1],
            "less_equal": [1],
            "greater": [1],
            "greater_equal": [1],
            "eq": [1],
            "ne": [1],

            "exp": "none",
            "log": "none",
            "neg": "none",
        }
        elementwise_name_to_op = {
            name: _elementwise_to_tensor(self, getattr(self._torch, name), allow_scalars=allow_scalars)
            for name, allow_scalars in elementwise_allow_scalars.items()
        }
        for name in tracer.signature.torch.elementwise_op_names:
            setattr(self, name, elementwise_name_to_op[name])

    def __getattr__(self, name):
        """Forward attributes that are not found on this object to the wrapped namespace."""
        if name == "_torch":
            # Reached only when ``_torch`` has not been assigned yet (e.g. on an
            # instance created without __init__). Delegating would recurse forever.
            raise AttributeError(name)
        return getattr(self._torch, name)

    def _to_tensor(self, x):
        """Return ``x`` as a tensor.

        Python/NumPy scalars and convertible tracers whose concrete type is a
        scalar, list, tuple or ``np.ndarray`` go through ``asarray``. Objects that
        already are one of ``valid_tensor_types`` are returned unchanged.

        Raises:
            ValueError: If ``x`` is a convertible tracer of an unsupported concrete
                type, or is not tensor-like at all.
        """
        if isinstance(x, (int, float, bool, np.integer, np.floating, np.bool_)):
            return self.asarray(x)
        elif isinstance(x, tracer.signature.classical.ConvertibleTensor):
            if issubclass(x.concrete.type, (int, float, bool, np.integer, np.floating, np.bool_, list, tuple, np.ndarray)):
                return self.asarray(x)
            else:
                raise ValueError(f"An object of type {x.concrete.type} cannot be used as a tensor in PyTorch") # TODO:
        elif isinstance(x, self.valid_tensor_types):
            return x
        else:
            raise ValueError(f"Expected a tensor, but got {type(x)}")

    def asarray(self, x):
        """Create a tensor from ``x`` on the device returned by ``_get_device``."""
        return self._torch.asarray(x, device=self._get_device())

    def reshape(self, x, shape):
        """Convert ``x`` to a tensor and reshape it to ``shape``."""
        x = self._to_tensor(x)
        return self._torch.reshape(x, shape)

    def permute(self, x, dims):
        """Convert ``x`` to a tensor and permute its dimensions according to ``dims``."""
        x = self._to_tensor(x)
        return self._torch.permute(x, dims)

    def broadcast_to(self, x, shape):
        """Convert ``x`` to a tensor and broadcast it to ``shape``."""
        x = self._to_tensor(x)
        return self._torch.broadcast_to(x, shape)

    def _reduce(self, x, op, **kwargs):
        """Convert ``x`` to a tensor and apply the reduction ``op`` to it.

        ``op`` is either a callable or the name of an attribute of the wrapped
        namespace; ``kwargs`` are passed to the reduction unchanged.
        """
        if isinstance(op, str):
            op = getattr(self._torch, op)
        x = self._to_tensor(x)
        return op(x, **kwargs)

    def take(self, x, index):
        """Convert ``x`` and ``index`` to tensors and delegate to ``take``."""
        x = self._to_tensor(x)
        index = self._to_tensor(index)
        return self._torch.take(x, index)

    def getitem(self, x, key):
        """Index the tensor ``x`` with ``key``.

        ``key`` is a single index or a tuple of indices. Convertible indices are
        turned into tensors unless their concrete type is an integer; everything
        else (including integer indices) is passed through unchanged.
        """
        x = self._to_tensor(x)

        def key_to_tensor(key):
            # The decision depends on the concrete type of the *key*, not of the
            # indexed tensor ``x``.
            if isinstance(key, tracer.signature.classical.ConvertibleTensor) and not issubclass(key.concrete.type, (int, np.integer)):
                return self.asarray(key)
            else:
                return key
        if isinstance(key, tuple):
            key = tuple(key_to_tensor(k) for k in key)
        else:
            key = key_to_tensor(key)

        return self._torch.getitem(x, key)

    def index_put_(self, x, key, value, *, accumulate):
        """Convert ``x``, ``value`` and every index in ``key`` to tensors and delegate to ``index_put_``."""
        x = self._to_tensor(x)
        value = self._to_tensor(value)

        if isinstance(key, tuple):
            key = tuple(self._to_tensor(k) for k in key)
        else:
            key = self._to_tensor(key)

        return self._torch.index_put_(x, key, value, accumulate=accumulate)

    def arange(self, n, dtype):
        """Delegate to ``arange`` with the given ``dtype`` on the device returned by ``_get_device``."""
        return self._torch.arange(n, dtype=dtype, device=self._get_device())

    def split(self, x, split_size_or_sections, dim=0):
        """Convert ``x`` to a tensor and split it along ``dim``."""
        x = self._to_tensor(x)
        return self._torch.split(x, split_size_or_sections, dim=dim)

    def cat(self, xs, dim=0):
        """Convert every element of ``xs`` to a tensor and concatenate them along ``dim``."""
        xs = [self._to_tensor(x) for x in xs]
        return self._torch.cat(xs, dim=dim)

    def dot(self, x, y):
        """Convert ``x`` and ``y`` to tensors and delegate to ``dot``."""
        x = self._to_tensor(x)
        y = self._to_tensor(y)
        return self._torch.dot(x, y)

    def einsum(self, subscripts, *operands):
        """Convert all ``operands`` to tensors and delegate to ``einsum``."""
        operands = [self._to_tensor(x) for x in operands]
        return self._torch.einsum(subscripts, *operands)

    def vmap(self, func, in_dims, out_dims):
        """Return a vmapped version of ``func`` whose positional arguments are converted to tensors first."""
        vmapped_func = self._torch.vmap(func, in_dims=in_dims, out_dims=out_dims)
        def vmapped_func_with_conversions(*args):
            args = [self._to_tensor(arg) for arg in args]
            return vmapped_func(*args)
        return vmapped_func_with_conversions

class TorchDeviceStack:
    """Per-thread stack of devices that is maintained while wrapped operations run.

    ``namedtensor.op`` / ``namedtensor.ops`` wrap one operation or a dict of
    operations so that the device of the first tensor argument is pushed for the
    duration of the call and popped afterwards. ``get_device`` returns the device
    on top of the calling thread's stack.
    """

    def __init__(self):
        self._thread_local = threading.local()
        self.namedtensor = types.SimpleNamespace(
            op=self._wrap_namedtensor_op,
            ops=self._wrap_namedtensor_ops,
        )

    def get_device(self):
        """Return the most recently pushed device of the calling thread (the stack must not be empty)."""
        devices = self._get_stack()
        assert len(devices) > 0
        return devices[-1]

    def _wrap_namedtensor_op(self, op):
        """Wrap ``op`` such that a device is pushed before and popped after each call."""
        @use_name_of(op)
        def inner(*tensors, out, **kwargs):
            self._enter([t.value for t in tensors])
            try:
                result = op(*tensors, out=out, **kwargs)
            finally:
                # Pop even if the operation raised.
                self._exit([t.value for t in tensors])
            return result
        return inner

    def _wrap_namedtensor_ops(self, ops):
        """Apply ``_wrap_namedtensor_op`` to every value of the ``ops`` mapping."""
        wrapped = {}
        for name, op in ops.items():
            wrapped[name] = self._wrap_namedtensor_op(op)
        return wrapped

    def _get_stack(self):
        """Return the calling thread's stack, creating an empty one on first use."""
        try:
            return self._thread_local.stack
        except AttributeError:
            self._thread_local.stack = []
            return self._thread_local.stack

    def _enter(self, tensors):
        """Push the device of the first classical tensor found in ``tensors``.

        Raises:
            ValueError: If no device could be determined from ``tensors``.
        """
        first_tensor = next(
            (t for t in tensors if isinstance(t, tracer.signature.classical.Tensor)),
            None,
        )
        if first_tensor is None:
            device = None
        else:
            device = tracer.signature.python.getattr(first_tensor, "device")
        if device is None:
            raise ValueError("Failed to determine the PyTorch device placement of parameters. Maybe convert the given arguments to a tensor first.")
        self._get_stack().append(device)

    def _exit(self, tensors):
        """Pop the most recently pushed device; ``tensors`` is not used."""
        devices = self._get_stack()
        assert len(devices) > 0
        del devices[-1]
