import einx._src.tracer as tracer
from functools import partial
import numpy as np
import types
from .vmap import vmap
from .einsum import einsum

class functional:
    """Shape-preserving wrappers for selected ops of a ``nn.functional`` namespace.

    Each wrapped op is bound (via ``functools.partial``) to
    ``tracer.signature.numpy.preserve_shape``.

    Args:
        functional: The namespace providing ``softmax`` and ``log_softmax``
            (presumably the traced ``torch.nn.functional`` — verify at caller).
    """

    def __init__(self, functional):
        self._functional = functional
        wrap = tracer.signature.numpy.preserve_shape
        for op_name in ("softmax", "log_softmax"):
            setattr(self, op_name, partial(wrap, op=getattr(functional, op_name)))

class torch:
    """Traced stand-in for the ``torch`` module.

    The real module is obtained through ``tracer.signature.python.import_("torch")``.
    Every wrapper below forwards to the corresponding torch function and then
    annotates the traced result with its statically computed output shape via
    ``tracer.cast(..., tracer.signature.classical.Tensor(shape=...))``.
    Attributes not defined on this class fall through to the traced torch
    module (see ``__getattr__``).
    """

    # Ops whose output shape is the broadcast of the tensor arguments' shapes.
    elementwise_op_names = [
        "add",
        "subtract",
        "multiply",
        "true_divide",
        "floor_divide",
        "divide",
        "remainder",
        "logical_and",
        "logical_or",
        "logical_xor",
        "where",
        "maximum",
        "minimum",
        "less",
        "less_equal",
        "greater",
        "greater_equal",
        "eq",
        "ne",
        "logaddexp",
        "exp",
        "log",
        "neg",
    ]
    # Ops whose output shape is derived from the ``dim`` / ``keepdim`` kwargs.
    reduce_op_names = [
        "sum",
        "mean",
        "var",
        "std",
        "prod",
        "count_nonzero",
        "all",
        "any",
        "amin",
        "amax",
        "logsumexp",
        "argmax",
        "argmin",
    ]
    # Ops wrapped with ``tracer.signature.numpy.preserve_shape``.
    preserve_shape_op_names = [
        "flip",
        "roll",
        "argsort",
    ]

    def __init__(self):
        self._torch = tracer.signature.python.import_("torch")
        self.einsum = einsum(self._torch.einsum)
        self.nn = types.SimpleNamespace(
            functional=functional(self._torch.nn.functional),
        )
        for name in self.elementwise_op_names:
            setattr(self, name, partial(self.elementwise, op=getattr(self._torch, name)))
        for name in self.reduce_op_names:
            setattr(self, name, partial(self.reduce, op=getattr(self._torch, name)))
        for name in self.preserve_shape_op_names:
            setattr(self, name, partial(tracer.signature.numpy.preserve_shape, op=getattr(self._torch, name)))

        def _backend_vmap(func, in_axes, out_axes):
            # The einx ``vmap`` helper calls the backend with ``in_axes``/``out_axes``;
            # torch.vmap names these ``in_dims``/``out_dims``.
            return self._torch.vmap(func, in_dims=in_axes, out_dims=out_axes)

        def _vmap(func, in_dims, out_dims):
            return vmap(_backend_vmap)(func, in_axes=in_dims, out_axes=out_dims)

        self.vmap = _vmap

    def __getattr__(self, name):
        """Forward attributes not found on this object to the traced torch module."""
        # ``__getattr__`` only runs when normal lookup fails. If ``_torch`` itself is
        # missing (e.g. an instance created by copy/pickle without running __init__),
        # reading ``self._torch`` would re-enter this method forever.
        if name == "_torch":
            raise AttributeError(name)
        return getattr(self._torch, name)

    def asarray(self, x, *, device):
        """Convert ``x`` to a tensor on ``device``.

        Tracer tensors keep their shape; Python/numpy int, float and bool scalars
        become 0-d tensors.

        Raises:
            ValueError: If ``x`` is of any other type.
        """
        if isinstance(x, (tracer.signature.classical.Tensor, tracer.signature.classical.ConvertibleTensor)):
            shape = x.shape
        elif isinstance(x, (int, float, bool, np.integer, np.floating, np.bool_)):
            shape = ()
        else:
            raise ValueError(f"Unsupported type {type(x)} for asarray operation")
        x = self._torch.asarray(x, device=device)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape))
        return x

    def sort(self, x, *args, **kwargs):
        """Return only the sorted values (element 0 of ``torch.sort``); shape is unchanged."""
        shape = x.shape
        x = self._torch.sort(x, *args, **kwargs)[0]
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape))
        return x

    def reshape(self, x, shape):
        """Reshape ``x``; a single ``-1`` in ``shape`` is inferred from ``x.shape``."""
        out_shape = self._reshaped_shape(x.shape, shape)
        x = self._torch.reshape(x, shape)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=out_shape))
        return x

    def permute(self, x, axes):
        """Permute the dimensions of ``x``; output dim ``i`` is input dim ``axes[i]``."""
        shape = tuple(x.shape[i] for i in axes)

        x = self._torch.permute(x, axes)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape))
        return x

    def broadcast_to(self, x, shape):
        """Broadcast ``x`` to ``shape``; the traced result is annotated with ``shape``."""
        x = self._torch.broadcast_to(x, shape)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape))
        return x

    @classmethod
    def reduce(cls, x, op, **kwargs):
        """Apply reduction ``op`` to ``x``.

        Only the ``dim`` (None, an int, or a sequence of ints; negatives allowed)
        and ``keepdim`` keyword arguments are accepted.

        Raises:
            TypeError: On any other keyword argument.
            IndexError: If a reduced dim is out of range for ``x``.
        """
        allowed_kwargs = {"dim", "keepdim"}
        for k in kwargs:
            if k not in allowed_kwargs:
                raise TypeError(f"Unexpected keyword argument '{k}' in reduce operation")
        shape = cls._reduced_shape(x.shape, kwargs.get("dim"), bool(kwargs.get("keepdim", False)))

        x = op(x, **kwargs)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape))
        return x

    @classmethod
    def elementwise(cls, *xs, op):
        """Apply elementwise ``op`` to ``xs``; output shape is the broadcast of the tensor shapes.

        Arguments without a ``shape`` attribute (e.g. Python scalars) do not
        contribute to the output shape.

        Raises:
            ValueError: If no argument has a shape, or the shapes do not broadcast.
        """
        shapes = [a.shape for a in xs if hasattr(a, "shape")]
        if not shapes:
            raise ValueError("elementwise operation requires at least one tensor as argument")
        shape = cls._broadcast_shapes(*shapes)

        x = op(*xs)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape))
        return x

    def take(self, x, index):
        """``torch.take``: gather from flattened ``x``; output has ``index``'s shape."""
        shape = index.shape

        x = self._torch.take(x, index)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape))
        return x

    def getitem(self, x, key):
        """Index ``x`` with ``key``.

        Supported key entries: integers / 0-d tensors (drop a dim), ``:`` (keep a
        dim) and ``None`` (insert a size-1 dim).

        Raises:
            ValueError: For an empty index tuple or more dim-consuming keys than dims.
            NotImplementedError: For any other key type.
        """
        keys = key if isinstance(key, tuple) else (key,)
        shape = self._getitem_shape(x.shape, keys)

        x = tracer.signature.python.getitem(x, key)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape))
        return x

    def index_put_(self, x, key, value, *, accumulate):
        """Call ``x.index_put_((key,), value, accumulate=...)``; the result keeps ``x``'s shape."""
        shape = x.shape
        x = tracer.signature.python.getattr(x, "index_put_")((key,), value, accumulate=accumulate)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape))
        return x

    def arange(self, n, dtype, *, device):
        """``torch.arange(n)`` with the given dtype/device; shape ``(n,)``."""
        shape = (n,)

        x = self._torch.arange(n, dtype=dtype, device=device)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape))
        return x

    def split(self, x, split_size_or_sections, dim=0):
        """``torch.split``: split ``x`` along ``dim``.

        An int is the size of each chunk (the last chunk may be smaller); a
        sequence gives the size of every chunk and must sum to ``x.shape[dim]``.

        Returns:
            A list of traced tensors, one per chunk.

        Raises:
            ValueError: If ``dim`` is out of bounds or the split sizes are invalid.
        """
        if dim < 0:
            dim += x.ndim
        if dim < 0 or dim >= x.ndim:
            raise ValueError(f"axis {dim} out of bounds for array of dimension {x.ndim}")

        lengths = self._split_lengths(x.shape[dim], split_size_or_sections)
        shapes = []
        for length in lengths:
            shape = list(x.shape)
            shape[dim] = length
            shapes.append(tuple(shape))

        xs = self._torch.split(x, split_size_or_sections, dim=dim)
        xs = tracer.cast(xs, lambda origin: [tracer.signature.classical.Tensor(origin, shape=s) for s in shapes])
        return xs

    def cat(self, xs, dim=0):
        """Concatenate ``xs`` along ``dim``; other dims are taken from ``xs[0]``.

        Raises:
            ValueError: If ``xs`` is empty or ``dim`` is out of bounds.
        """
        if len(xs) == 0:
            raise ValueError("cat requires at least one tensor")
        if dim < 0:
            dim += xs[0].ndim
        if dim < 0 or dim >= xs[0].ndim:
            raise ValueError(f"dim {dim} out of bounds for array of dimension {xs[0].ndim}")

        shape = list(xs[0].shape)
        shape[dim] = sum(x.shape[dim] for x in xs)
        shape = tuple(shape)

        x = self._torch.cat(xs, dim=dim)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape))
        return x

    def dot(self, x, y):
        """Dot product of two 1-D tensors; the result is 0-d.

        Raises:
            ValueError: If either argument is not 1-D.
        """
        if x.ndim != 1 or y.ndim != 1:
            raise ValueError(
                f"dot only supports 1D tensors, got {x.ndim}D and {y.ndim}D"
            )

        x = self._torch.dot(x, y)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=tuple()))
        return x

    # ---- static shape helpers (pure functions of shapes, no tracing) ----

    @staticmethod
    def _broadcast_shapes(*shapes):
        """Return the broadcast of ``shapes`` as a tuple (NumPy/torch broadcasting rules).

        Raises:
            ValueError: If two non-1 sizes in the same position differ.
        """
        ndim = max((len(s) for s in shapes), default=0)
        # Left-pad every shape with 1s to a common rank.
        padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
        result = []
        for dims in zip(*padded):
            out = 1
            for d in dims:
                if d == 1:
                    continue
                if out == 1:
                    out = d
                elif d != out:
                    raise ValueError(f"Shapes {[tuple(s) for s in shapes]} cannot be broadcast together")
            result.append(out)
        return tuple(result)

    @staticmethod
    def _reduced_shape(shape, dim=None, keepdim=False):
        """Return the output shape of reducing ``shape`` over ``dim``.

        ``dim`` is None (all dims), an int, or a sequence of ints; negative dims
        count from the end. Reduced dims become 1 if ``keepdim`` else are removed.

        Raises:
            IndexError: If a dim is out of range.
        """
        ndim = len(shape)
        if dim is None:
            axes = range(ndim)
        elif isinstance(dim, (int, np.integer)):
            axes = [dim]
        else:
            axes = dim
        reduced = set()
        for a in axes:
            if not -ndim <= a < ndim:
                raise IndexError(f"Dimension {a} out of range for tensor of dimension {ndim}")
            reduced.add(a + ndim if a < 0 else a)
        if keepdim:
            return tuple(1 if i in reduced else d for i, d in enumerate(shape))
        return tuple(d for i, d in enumerate(shape) if i not in reduced)

    @staticmethod
    def _split_lengths(size, split_size_or_sections):
        """Return the chunk lengths ``torch.split`` produces for a dim of ``size``.

        Raises:
            ValueError: For a negative (or zero, on a non-empty dim) chunk size, or
                sections that do not sum to ``size``.
        """
        size = int(size)
        if isinstance(split_size_or_sections, (int, np.integer)):
            chunk = int(split_size_or_sections)
            if chunk < 0 or (chunk == 0 and size != 0):
                raise ValueError(f"split_size must be a positive integer, got {chunk}")
            if chunk == 0:
                return [0]
            # torch always returns at least one chunk, even for an empty dim.
            num_chunks = max(-(-size // chunk), 1)
            return [chunk] * (num_chunks - 1) + [size - chunk * (num_chunks - 1)]
        lengths = [int(length) for length in split_size_or_sections]
        if sum(lengths) != size:
            raise ValueError(f"split sections {lengths} do not sum to dimension size {size}")
        return lengths

    @staticmethod
    def _getitem_shape(in_shape, keys):
        """Return the shape of indexing a tensor of ``in_shape`` with the key tuple ``keys``.

        Raises:
            ValueError: For an empty key tuple or too many dim-consuming keys.
            NotImplementedError: For unsupported key types.
        """
        ndim = len(in_shape)
        # ``None`` inserts a new axis and does not consume an input dim.
        if sum(k is not None for k in keys) > ndim:
            raise ValueError(f"Too many indices for tensor of dimension {ndim}")
        elif len(keys) == 0:
            raise ValueError("Empty index tuple")

        in_shape = tuple(in_shape)
        shape = []
        pos = 0  # next input dim not yet consumed by a key
        for k in keys:
            if k is None:
                shape.append(1)
            elif isinstance(k, (np.integer, int)) or (hasattr(k, "ndim") and k.ndim == 0):
                pos += 1
            elif isinstance(k, slice) and k == slice(None):
                shape.append(in_shape[pos])
                pos += 1
            else:
                raise NotImplementedError(f"Key type {type(k)} not supported")
        return tuple(shape) + in_shape[pos:]

    @staticmethod
    def _reshaped_shape(in_shape, shape):
        """Return ``shape`` as a tuple, inferring a single ``-1`` entry from ``in_shape``.

        Raises:
            ValueError: If more than one ``-1`` is given or the sizes are incompatible.
        """
        if isinstance(shape, (int, np.integer)):
            shape = (shape,)
        shape = tuple(shape)
        inferred = [i for i, d in enumerate(shape) if d == -1]
        if not inferred:
            return shape
        if len(inferred) > 1:
            raise ValueError(f"only one dimension can be inferred in shape {shape}")
        total = 1
        for d in in_shape:
            total *= d
        known = 1
        for d in shape:
            if d != -1:
                known *= d
        if known == 0 or total % known != 0:
            raise ValueError(f"cannot reshape tensor of shape {tuple(in_shape)} into shape {shape}")
        i = inferred[0]
        return shape[:i] + (total // known,) + shape[i + 1:]