import einx._src.tracer as tracer
from functools import partial
import numpy as np
from .einsum import einsum

class numpy:
    """Tracing wrapper around a NumPy-like module.

    Each wrapped operation calls the corresponding function of the wrapped module
    (``self._np``). It then passes the result through ``tracer.cast`` together with an
    output shape computed here, wrapping it as ``tracer.signature.classical.Tensor``.
    Attributes not defined on this class (or set in ``__init__``) are forwarded to the
    wrapped module by ``__getattr__``.
    """

    # Function names of the wrapped module exposed as shape-aware elementwise ops.
    elementwise_op_names = [
        "add",
        "subtract",
        "multiply",
        "true_divide",
        "floor_divide",
        "divide",
        "logical_and",
        "logical_or",
        "logical_xor",
        "where",
        "maximum",
        "minimum",
        "less",
        "less_equal",
        "greater",
        "greater_equal",
        "equal",
        "not_equal",
        "logaddexp",
        "exp",
        "log",
        "negative",
    ]
    # Function names of the wrapped module exposed via ``reduce``.
    reduce_op_names = [
        "sum",
        "mean",
        "var",
        "std",
        "prod",
        "count_nonzero",
        "all",
        "any",
        "min",
        "max",
        "argmax",
        "argmin",
    ]
    # Elementwise ops that additionally receive an ``.at`` attribute (see ``_at``).
    updateat_op_names = [
        "add",
        "subtract",
    ]
    # Ops whose output shape equals their input shape (see ``preserve_shape``).
    preserve_shape_names = [
        "roll",
        "flip",
        "sort",
        "argsort",
    ]

    def __init__(self, np=None):
        """Create the wrapper.

        Args:
            np: Module-like object to forward calls to. If ``None``, it is obtained via
                ``tracer.signature.python.import_("numpy", as_="np")``.
        """
        self._np = np if np is not None else tracer.signature.python.import_("numpy", as_="np")
        self.einsum = einsum(self._np.einsum)

        for name in self.elementwise_op_names:
            setattr(self, name, partial(self.elementwise, op=getattr(self._np, name)))
        for name in self.reduce_op_names:
            setattr(self, name, partial(self.reduce, op=getattr(self._np, name)))
        # Attach ``.at`` to the partials created above (e.g. ``self.add.at``), mirroring
        # ``ufunc.at``. This must run after the elementwise loop.
        for name in self.updateat_op_names:
            setattr(getattr(self, name), "at", partial(self._at, op=name))
        for name in self.preserve_shape_names:
            setattr(self, name, partial(self.preserve_shape, op=getattr(self._np, name)))

    def __getattr__(self, name):
        """Forward unknown attribute lookups to the wrapped module."""
        # ``__getattr__`` only runs when normal lookup fails. If ``_np`` itself is missing
        # (instance created via ``__new__``, e.g. by copy/pickle), reading ``self._np``
        # below would recurse forever. Report a plain AttributeError instead.
        if name == "_np":
            raise AttributeError(name)
        return getattr(self._np, name)

    def asarray(self, x):
        """Convert ``x`` with the wrapped ``asarray``; the result keeps ``x.shape``."""
        shape = x.shape
        x = self._np.asarray(x)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape))
        return x

    def reshape(self, x, shape):
        """Reshape ``x`` to ``shape`` (an int or a sequence of ints)."""
        # Accept a bare int like NumPy does and annotate with a tuple shape.
        # The call itself still receives the caller's ``shape`` unchanged.
        out_shape = (shape,) if isinstance(shape, (int, np.integer)) else tuple(shape)
        x = self._np.reshape(x, shape)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=out_shape))
        return x

    def transpose(self, x, axes):
        """Permute the dimensions of ``x`` according to ``axes``."""
        shape = tuple(x.shape[i] for i in axes)

        x = self._np.transpose(x, axes)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape))
        return x

    def broadcast_to(self, x, shape):
        """Broadcast ``x`` to ``shape`` (an int or a sequence of ints)."""
        out_shape = (shape,) if isinstance(shape, (int, np.integer)) else tuple(shape)
        x = self._np.broadcast_to(x, shape)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=out_shape))
        return x

    @staticmethod
    def reduce(x, *, op, **kwargs):
        """Apply reduction ``op`` to ``x`` and annotate the reduced shape.

        Args:
            x: Input tensor; must expose ``.shape``.
            op: Reduction callable, called as ``op(x, **kwargs)``.
            **kwargs: Forwarded to ``op``. ``axis`` (None, int, or sequence of ints) and
                ``keepdims`` are also used to compute the output shape.

        Raises:
            ValueError: If an axis is out of bounds or appears more than once.
        """
        ndim = len(x.shape)
        axis = kwargs.get("axis", None)
        if axis is None:
            axes = list(range(ndim))
        elif isinstance(axis, (int, np.integer)):
            axes = [axis]
        else:
            axes = list(axis)

        # Normalize negative axes. Deleting a mix of negative and non-negative indices
        # from ``shape`` below would otherwise remove the wrong dimensions.
        normalized = []
        for a in axes:
            a2 = a + ndim if a < 0 else a
            if a2 < 0 or a2 >= ndim:
                raise ValueError(f"axis {a} out of bounds for array of dimension {ndim}")
            normalized.append(int(a2))
        if len(set(normalized)) != len(normalized):
            raise ValueError("duplicate value in 'axis'")

        shape = list(x.shape)
        if kwargs.get("keepdims", False):
            for a in normalized:
                shape[a] = 1
        else:
            # Delete from the back so earlier indices stay valid.
            for a in sorted(normalized, reverse=True):
                del shape[a]
        shape = tuple(shape)

        x = op(x, **kwargs)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape))
        return x

    def divmod(self, x, y):
        """Trace ``divmod(x, y)`` and return a pair of tensors of the broadcast shape."""
        x = numpy.elementwise(x, y, op=self._np.divmod)
        # ``elementwise`` wrapped the op output with ``tracer.cast``. Unwrap it so the
        # raw divmod output can be re-cast as a 2-tuple of tensors instead.
        assert isinstance(x.origin, tracer.Cast)
        return tracer.cast(x.origin.input, lambda origin: (
            tracer.signature.classical.Tensor(origin, shape=x.shape),
            tracer.signature.classical.Tensor(origin, shape=x.shape),
        ))

    @staticmethod
    def elementwise(*xs, op):
        """Apply ``op`` to ``xs`` and annotate the result with the broadcast shape.

        Arguments without a ``.shape`` attribute (e.g. Python scalars) do not
        contribute to the output shape.

        Raises:
            ValueError: If no argument has a shape, or if the shapes are not
                broadcast-compatible.
        """
        shape = None
        for a in xs:
            if not hasattr(a, "shape"):
                continue
            a_shape = tuple(a.shape)
            if shape is None:
                shape = a_shape
                continue
            # NumPy broadcasting: left-pad the shorter shape with 1s, then each pair of
            # dims must be equal or one of them must be 1.
            ndim = max(len(shape), len(a_shape))
            s1 = (1,) * (ndim - len(shape)) + shape
            s2 = (1,) * (ndim - len(a_shape)) + a_shape
            merged = []
            for d1, d2 in zip(s1, s2):
                if d1 == 1:
                    merged.append(d2)
                elif d2 == 1 or d1 == d2:
                    merged.append(d1)
                else:
                    raise ValueError(f"operands could not be broadcast together with shapes {s1} and {s2}")
            # Keep ``shape`` a tuple so the padding above keeps working for later operands.
            shape = tuple(merged)
        if shape is None:
            raise ValueError("elementwise operation requires at least one tensor as argument")

        x = op(*xs)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape))
        return x

    @staticmethod
    def preserve_shape(x, *, op, **kwargs):
        """Apply ``op(x, **kwargs)``; the result is annotated with ``x``'s shape."""
        shape = tuple(x.shape)
        x = op(x, **kwargs)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape))
        return x

    def take(self, x, indices, axis=None):
        """Take elements of ``x`` at ``indices``, optionally along ``axis``.

        Raises:
            ValueError: If ``axis`` is out of bounds for ``x``.
        """
        # Scalar indices (no ``.shape``) contribute no dimensions.
        indices_shape = tuple(indices.shape) if hasattr(indices, "shape") else tuple()
        if axis is None:
            # With ``axis=None`` NumPy indexes the flattened input, so the result has
            # the shape of ``indices``.
            shape = indices_shape
        else:
            axis2 = axis
            if axis2 < 0:
                axis2 += x.ndim
            if axis2 < 0 or axis2 >= x.ndim:
                raise ValueError(f"axis {axis} out of bounds for array of dimension {x.ndim}")
            shape = tuple(x.shape[:axis2]) + indices_shape + tuple(x.shape[axis2 + 1:])

        x = self._np.take(x, indices, axis=axis)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape))
        return x

    def put(self, x, indices, values):
        """Trace an in-place ``put(x, indices, values)`` on ``x``."""
        return tracer.signature.python.call_inplace(x, self._np.put, x, indices, values)

    def _at(self, x, indices, values, op):
        """Trace an in-place ``<op>.at(x, indices, values)`` on ``x`` (``op`` is a name)."""
        return tracer.signature.python.call_inplace(x, getattr(self._np, op).at, x, indices, values)

    def getitem(self, tensor, key):
        """Trace ``tensor[key]`` for a restricted set of key types.

        Supported keys (alone or in a tuple): integers or 0-dim tensors (drop a dim),
        ``:`` and ``::-1`` (keep a dim), and ``None`` (insert a dim of size 1).

        Raises:
            ValueError: If the key tuple is empty or indexes more dims than ``tensor`` has.
            NotImplementedError: For any other key type.
        """
        keys = key if isinstance(key, tuple) else (key,)

        # ``None`` inserts a new axis and does not consume an input dimension.
        if sum(1 for k in keys if k is not None) > tensor.ndim:
            raise ValueError(f"Too many indices for tensor of dimension {tensor.ndim}")
        elif len(keys) == 0:
            raise ValueError("Empty index tuple")

        in_shape = list(tensor.shape)
        shape = []
        for k in keys:
            if k is None:
                shape.append(1)
            elif isinstance(k, (np.integer, int)) or (isinstance(k, (tracer.signature.classical.Tensor, tracer.signature.classical.ConvertibleTensor)) and k.ndim == 0):
                in_shape = in_shape[1:]
            elif isinstance(k, slice) and k.start is None and k.stop is None and (k.step is None or k.step == -1):
                shape.append(in_shape[0])
                in_shape = in_shape[1:]
            else:
                raise NotImplementedError(f"Key type {type(k)} not supported")
        # Dims not addressed by any key are kept as-is.
        shape = tuple(shape) + tuple(in_shape)

        x = tracer.signature.python.getitem(tensor, key)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape))
        return x

    def arange(self, n, dtype="int32"):
        """Trace ``arange(n, dtype=dtype)``; the result has shape ``(n,)``."""
        shape = (n,)

        x = self._np.arange(n, dtype=dtype)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape))
        return x

    def split(self, x, indices_or_sections, axis=0):
        """Split ``x`` along ``axis`` into a list of tensors.

        Args:
            x: Tensor to split.
            indices_or_sections: Number of equal sections (int) or a sequence of split
                points along ``axis``.
            axis: Axis to split along; negative values count from the end.

        Raises:
            ValueError: If ``axis`` is out of bounds, the section count is not positive,
                or the axis length is not divisible by the section count.
        """
        if axis < 0:
            axis += x.ndim
        if axis < 0 or axis >= x.ndim:
            raise ValueError(f"axis {axis} out of bounds for array of dimension {x.ndim}")

        dim = x.shape[axis]
        if isinstance(indices_or_sections, (int, np.integer)):
            if indices_or_sections <= 0:
                raise ValueError("number sections must be larger than 0.")
            section_length = dim // indices_or_sections
            # Divisibility must be checked along the split axis, not axis 0.
            if section_length * indices_or_sections != dim:
                raise ValueError(
                    f"array split does not result in an equal division: "
                    f"{dim} % {indices_or_sections} != 0"
                )
            lengths = [section_length] * int(indices_or_sections)
        else:
            # Split points become consecutive boundaries [0, *points, dim].
            indices = [0] + list(indices_or_sections) + [dim]
            lengths = [
                indices[i + 1] - indices[i]
                for i in range(len(indices) - 1)
            ]

        shapes = []
        for length in lengths:
            shape = list(x.shape)
            shape[axis] = int(length)
            shapes.append(tuple(shape))

        xs = self._np.split(x, indices_or_sections, axis=axis)
        xs = tracer.cast(xs, lambda origin: list(tracer.signature.classical.Tensor(origin, shape=shape) for shape in shapes))
        return xs

    def concatenate(self, xs, axis=0):
        """Concatenate the tensors ``xs`` along ``axis``.

        The output shape is taken from ``xs[0]``, with ``axis`` set to the sum of the
        inputs' sizes along it.

        Raises:
            ValueError: If ``xs`` is empty or ``axis`` is out of bounds.
        """
        if len(xs) == 0:
            raise ValueError("need at least one array to concatenate")
        if axis < 0:
            axis += xs[0].ndim
        if axis < 0 or axis >= xs[0].ndim:
            raise ValueError(f"axis {axis} out of bounds for array of dimension {xs[0].ndim}")

        shape = list(xs[0].shape)
        shape[axis] = sum(x.shape[axis] for x in xs)
        shape = tuple(shape)

        x = self._np.concatenate(xs, axis=axis)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=shape))
        return x

    def dot(self, x, y):
        """Trace the inner product of two 1-D tensors; the result is 0-dim.

        Raises:
            ValueError: If either input is not 1-D.
        """
        if x.ndim != 1 or y.ndim != 1:
            raise ValueError(
                f"dot only supports 1D tensors, got {x.ndim}D and {y.ndim}D"
            )

        x = self._np.dot(x, y)
        x = tracer.cast(x, partial(tracer.signature.classical.Tensor, shape=tuple()))
        return x
