from .._util import _associative_binary_to_nary
from .._util import _axis_to_axisint
from .._util import _axis_to_axistuple
from .._util import _einsum_diag_string
from functools import partial
import numpy as _np
import einx._src.adapter as adapter
from ._util import _to_tensor


class classical_from_numpy:
    """Expose a "classical" tensor-operation interface backed by a NumPy-like namespace.

    Args:
        np: Array namespace whose functions implement the operations (it must
            provide e.g. ``add``, ``sum``, ``put``, ``reshape``, ``einsum``,
            ``take``, ``copy``). Tensor arguments are converted with the
            project's ``_to_tensor`` helper bound to this namespace before
            being passed on.

    Elementwise, reduction, scatter-update and shape-preserving operations are
    generated in ``__init__`` from name -> function tables and attached as
    instance attributes (``self.add``, ``self.sum``, ``self.add_at``,
    ``self.softmax``, ...). The remaining operations are regular methods.
    """

    def __init__(self, np):
        self._np = np
        self._to_tensor = partial(_to_tensor, np=np)
        # Binary ops wrapped with _associative_binary_to_nary presumably accept
        # more than two operands (folded pairwise) -- see that helper.
        elementwise_name_to_op = {
            "add": _associative_binary_to_nary(self._np.add),
            "subtract": self._np.subtract,
            "multiply": _associative_binary_to_nary(self._np.multiply),
            "true_divide": self._np.true_divide,
            "floor_divide": self._np.floor_divide,
            "divide": self._np.divide,
            "logaddexp": _associative_binary_to_nary(self._np.logaddexp),
            "logical_and": _associative_binary_to_nary(self._np.logical_and),
            "logical_or": _associative_binary_to_nary(self._np.logical_or),
            "logical_xor": _associative_binary_to_nary(self._np.logical_xor),
            "where": self._np.where,
            "maximum": _associative_binary_to_nary(self._np.maximum),
            "minimum": _associative_binary_to_nary(self._np.minimum),
            "less": self._np.less,
            "less_equal": self._np.less_equal,
            "greater": self._np.greater,
            "greater_equal": self._np.greater_equal,
            "equal": self._np.equal,
            "not_equal": self._np.not_equal,
            "exp": self._np.exp,
            "log": self._np.log,
            "negative": self._np.negative,
        }
        reduce_name_to_op = {
            "sum": self._np.sum,
            "mean": self._np.mean,
            "var": self._np.var,
            "std": self._np.std,
            "prod": self._np.prod,
            "count_nonzero": self._np.count_nonzero,
            "any": self._np.any,
            "all": self._np.all,
            "max": self._np.max,
            "min": self._np.min,
            # Composite op built by the adapter package from this object's own
            # classical primitives (no stop_gradient available for NumPy).
            "logsumexp": adapter.logsumexp_from_classical(self, stop_gradient=None),
            "argmax": self._np.argmax,
            "argmin": self._np.argmin,
        }
        # NOTE: np.put and np.<ufunc>.at work in place and return None;
        # _update_at accounts for that.
        update_at_name_to_op = {
            "set_at": self._np.put,
            "add_at": self._np.add.at,
            "subtract_at": self._np.subtract.at,
        }
        preserve_shape_name_to_op = {
            "softmax": adapter.softmax_from_classical(self, stop_gradient=None),
            "log_softmax": adapter.log_softmax_from_classical(self, stop_gradient=None),
            "flip": self._np.flip,
        }

        # Attach each table entry as a bound attribute, routed through the
        # matching generic dispatcher (which converts inputs to tensors).
        for name, op in elementwise_name_to_op.items():
            setattr(self, name, partial(self._elementwise, op=op))
        for name, op in reduce_name_to_op.items():
            setattr(self, name, partial(self._reduce, op=op))
        for name, op in update_at_name_to_op.items():
            setattr(self, name, partial(self._update_at, op=op))
        for name, op in preserve_shape_name_to_op.items():
            setattr(self, name, partial(self._preserve_shape, op=op))

    @staticmethod
    def _default_sort_axis(x, axis):
        """Return ``axis``, defaulting it to 0 when it is ``None`` and ``x`` is 1D.

        Raises:
            ValueError: if ``axis`` is ``None`` and ``x`` is not 1D.
        """
        if axis is not None:
            return axis
        if x.ndim == 1:
            return 0
        raise ValueError("When 'axis' is not specified, 'x' must be a 1D array.")

    def sort(self, x, axis=None, **kwargs):
        """Sort ``x`` along a single ``axis`` (optional only for 1D ``x``)."""
        # Convert first so that ``x.ndim`` is available for non-array inputs.
        x = self._to_tensor(x)
        axis = self._default_sort_axis(x, axis)
        return self._preserve_shape(x, op=self._np.sort, axis=_axis_to_axisint(axis), **kwargs)

    def argsort(self, x, axis=None, **kwargs):
        """Return the indices that sort ``x`` along a single ``axis`` (optional only for 1D ``x``)."""
        x = self._to_tensor(x)
        axis = self._default_sort_axis(x, axis)
        return self._preserve_shape(x, op=self._np.argsort, axis=_axis_to_axisint(axis), **kwargs)

    def roll(self, x, *, shift, axis=None):
        """Roll ``x`` by ``shift`` along ``axis`` (all axes when ``axis`` is ``None``).

        Raises:
            ValueError: if ``shift`` is a sequence whose length differs from the
                number of rolled axes.
        """
        x = self._to_tensor(x)
        if axis is None:
            # Roll every axis individually instead of NumPy's flatten-and-roll.
            axis = tuple(range(x.ndim))
        axis = _axis_to_axistuple(axis)
        if isinstance(shift, (list, tuple, _np.ndarray)) and len(shift) != len(axis):
            raise ValueError(f"Expected the 'shift' argument to have a length of {len(axis)}, but got length {len(shift)}")
        return self._preserve_shape(x, op=self._np.roll, shift=shift, axis=axis)

    def reshape(self, x, shape):
        """Reshape ``x`` to ``shape``."""
        x = self._to_tensor(x)
        return self._np.reshape(x, tuple(shape))

    def transpose(self, x, axes):
        """Permute the axes of ``x`` according to ``axes``."""
        x = self._to_tensor(x)
        return self._np.transpose(x, tuple(axes))

    def broadcast_to(self, x, shape):
        """Broadcast ``x`` to ``shape``."""
        x = self._to_tensor(x)
        return self._np.broadcast_to(x, tuple(shape))

    def diagonal(self, x, *, axes_in, axis_out):
        """Extract a diagonal of ``x`` via an einsum subscript built by ``_einsum_diag_string``."""
        x = self._to_tensor(x)
        einsum_str = _einsum_diag_string(x.ndim, axes_in, axis_out)
        return self._np.einsum(einsum_str, x)

    @staticmethod
    def _axis_kwarg_to_tuple(kwargs):
        """Convert a list/ndarray ``axis`` entry of ``kwargs`` (in place) to a tuple.

        NumPy functions accept an int or a tuple for ``axis``, not a list.
        """
        axis = kwargs.get("axis")
        if isinstance(axis, (list, _np.ndarray)):
            kwargs["axis"] = tuple(axis)
        return kwargs

    def _reduce(self, x, *, op, **kwargs):
        """Apply reduction ``op`` to ``x``, forwarding ``kwargs`` (e.g. ``axis``)."""
        x = self._to_tensor(x)
        return op(x, **self._axis_kwarg_to_tuple(kwargs))

    def _elementwise(self, *xs, op):
        """Apply elementwise ``op`` to all operands ``xs`` after tensor conversion."""
        xs = [self._to_tensor(x) for x in xs]
        return op(*xs)

    def divmod(self, x, y):
        """Return ``np.divmod(x, y)``, i.e. the pair (floor quotient, remainder)."""
        x = self._to_tensor(x)
        y = self._to_tensor(y)
        return self._np.divmod(x, y)

    def _preserve_shape(self, x, *, op, **kwargs):
        """Apply shape-preserving ``op`` to ``x``, forwarding ``kwargs`` (e.g. ``axis``)."""
        x = self._to_tensor(x)
        return op(x, **self._axis_kwarg_to_tuple(kwargs))

    def get_at(self, x, indices, *, axis=None):
        """Gather values of ``x`` at ``indices``.

        Args:
            x: Input tensor.
            indices: If ``axis`` is ``None``, a tuple with one equally-shaped
                index array per dimension of ``x``. Otherwise an index (scalar
                or array) for the single dimension ``axis``.
            axis: Dimension to index, may be negative.

        Raises:
            ValueError: on malformed ``indices`` or an out-of-range ``axis``.
            NotImplementedError: for array indices on a non-1D ``x`` along ``axis``.
        """
        x = self._to_tensor(x)
        if axis is None:
            # Multi-dimensional indexing
            if not isinstance(indices, tuple):
                raise ValueError(f"Expected indices to be a tuple, but got {type(indices)}")
            if len(indices) != x.ndim:
                raise ValueError(f"Expected indices to have the same number of elements as x.ndim, but got {len(indices)} and {x.ndim}")
            indices_shapes = {i.shape for i in indices}
            if len(indices_shapes) != 1:
                raise ValueError(f"Expected all indices to have the same shape, but got {indices_shapes}")
            # Advanced indexing: result[k] = x[indices[0][k], indices[1][k], ...].
            return x[indices]

        # Single-dimensional indexing: normalize/validate the axis once so that
        # negative axes work in every branch below.
        if axis < 0:
            axis += x.ndim
        if axis < 0 or axis >= x.ndim:
            raise ValueError(f"Invalid axis {axis} for array of dimension {x.ndim}")
        if not hasattr(indices, "shape") or indices.ndim == 0:
            return x[(slice(None),) * axis + (indices,)]
        if x.ndim == 1:
            return self._np.take(x, indices, axis=0)
        raise NotImplementedError("Don't know how to express this operation in NumPy")

    def _update_at(self, x, indices, updates, *, op):
        """Return a copy of 1D ``x`` with ``op`` applied at ``indices`` using ``updates``.

        ``op`` is e.g. ``np.put`` or ``np.add.at``. Those mutate their first
        argument in place and return ``None``, so the op runs on a copy of ``x``.
        The caller's array is left untouched and the updated copy is returned.
        If ``op`` is out-of-place instead, its return value is used.

        Raises:
            ValueError: if ``x`` is not 1D or ``indices``/``updates`` differ in ndim.
        """
        x = self._to_tensor(x)
        if x.ndim != 1:
            raise ValueError(f"Expected 1D array, but got {x.ndim}D")
        if indices.ndim != updates.ndim:
            raise ValueError(f"Expected indices and updates to have the same number of dimensions, but got {indices.ndim}D and {updates.ndim}D")
        out = self._np.copy(x)
        result = op(out, indices, updates)
        return out if result is None else result

    def arange(self, n, dtype="int32"):
        """Return ``[0, 1, ..., n - 1]`` with the given ``dtype``."""
        return self._np.arange(n, dtype=dtype)

    def split(self, x, indices_or_sections, axis=0):
        """Split ``x`` along ``axis`` via ``np.split``."""
        x = self._to_tensor(x)
        return self._np.split(x, indices_or_sections, axis=axis)

    def concatenate(self, xs, axis=0):
        """Concatenate the tensors ``xs`` along ``axis``."""
        xs = [self._to_tensor(x) for x in xs]
        return self._np.concatenate(xs, axis=axis)

    def dot(self, x, y):
        """Return the full inner product ``sum(x * y)`` of two equally-shaped tensors.

        Raises:
            ValueError: if the shapes of ``x`` and ``y`` differ.
        """
        x = self._to_tensor(x)
        y = self._to_tensor(y)
        if x.shape != y.shape:
            raise ValueError(f"Expected x and y to have the same shape, but got {x.shape} and {y.shape}")
        # Flatten both operands to 1D so np.dot contracts every element.
        # A 0-d input is special-cased because np.prod(()) yields the float 1.0.
        if x.ndim == 0:
            shape = (1,)
        else:
            shape = (_np.prod(x.shape),)
        x = self.reshape(x, shape)
        y = self.reshape(y, shape)
        return self._np.dot(x, y)
