import numpy as np
import einx._src.tracer as tracer

class _numpy_from_mlx:
    """Adapter exposing a NumPy-style function namespace on top of ``mlx.core``.

    The methods defined here cover NumPy names that need a different spelling
    or some argument handling before being delegated; every other attribute
    is forwarded unchanged to ``mlx.core`` by ``__getattr__``.
    """

    def __init__(self, mlx):
        """Keep ``mlx.core`` as the module that all calls are delegated to."""
        self._mx = mlx.core

    def _scalar_to_tensor(self, x):
        """Wrap Python/NumPy scalars with ``array``; return anything else as is.

        For a ``ConvertibleTensor`` the decision is based on the type recorded
        in ``x.concrete.type`` instead of the type of ``x`` itself.
        """
        if isinstance(x, tracer.signature.classical.ConvertibleTensor):
            concrete_type = x.concrete.type
        else:
            concrete_type = type(x)

        scalar_types = (int, float, bool, np.integer, np.floating, np.bool_)
        if issubclass(concrete_type, scalar_types):
            return self._mx.array(x)
        return x

    def _to_dtype(self, x):
        """Resolve a dtype given by name to the attribute of that name.

        Strings are looked up on ``mlx.core``; any other value is returned
        untouched.
        """
        if isinstance(x, str):
            return getattr(self._mx, x)
        return x

    def reshape(self, x, shape):
        """Reshape ``x`` to ``shape``, converting scalar inputs to arrays first."""
        x = self._scalar_to_tensor(x)
        return self._mx.reshape(x, shape)

    def arange(self, n, dtype="int32"):
        """Call ``arange(n)`` with ``dtype`` given as a name or dtype object.

        Raises:
            ValueError: If ``n`` is neither an ``int`` nor a NumPy integer.
        """
        if not isinstance(n, (int, np.integer)):
            raise ValueError(f"Expected an integer for n, but got {type(n)}")
        return self._mx.arange(n, dtype=self._to_dtype(dtype))

    def asarray(self, x):
        """NumPy's ``asarray`` spelled as ``array``."""
        return self._mx.array(x)

    def true_divide(self, x, y):
        """NumPy's ``true_divide`` spelled as ``divide``."""
        return self._mx.divide(x, y)

    def logical_xor(self, x, y):
        """Element-wise exclusive or of the truth values of ``x`` and ``y``.

        Both operands are first reduced to "is non-zero" via ``not_equal(., 0)``;
        the result is true where exactly one of them is non-zero.
        """
        x = self._mx.not_equal(x, 0)
        y = self._mx.not_equal(y, 0)
        return self._mx.not_equal(x, y)

    def put(self, x, indices, values):
        """Delegate to ``setitem(x, indices, values)`` and return its result."""
        return self._mx.setitem(x, indices, values)

    def flip(self, x, axis=None):
        """Reverse the order of the elements of ``x`` along the given axes.

        Args:
            x: Array-like object exposing ``ndim``.
            axis: A single axis (``int`` or NumPy integer) or an iterable of
                axes. Negative values count from the last axis. ``None``
                (the default) reverses every axis, as ``numpy.flip`` does.

        Returns:
            The result of ``getitem`` with a ``[::-1]`` slice on every selected
            axis and a full slice on all others.

        Raises:
            ValueError: If an axis lies outside ``[-x.ndim, x.ndim)``.
        """
        ndim = x.ndim

        if axis is None:
            axes = range(ndim)
        elif isinstance(axis, (int, np.integer)):
            # Same notion of "integer" as ``arange`` uses.
            axes = (axis,)
        else:
            axes = axis

        def _normalize(a):
            # Validate first so that the error reports the axis as it was given.
            if a < -ndim or a >= ndim:
                raise ValueError(f"Invalid axis {a} for array with {ndim} dimensions.")
            return a + ndim if a < 0 else a

        flipped = tuple(_normalize(a) for a in axes)

        index = tuple(
            slice(None, None, -1) if i in flipped else slice(None)
            for i in range(ndim)
        )
        return self._mx.getitem(x, index)

    def count_nonzero(self, x, axis=None):
        """Count the entries of ``x`` that differ from zero, summed over ``axis``."""
        return self._mx.sum(self._mx.not_equal(x, 0), axis=axis)

    @property
    def ndarray(self):
        """The ``array`` attribute of ``mlx.core`` under NumPy's ``ndarray`` name."""
        return self._mx.array

    def __getattr__(self, name):
        """Forward every attribute not defined on the adapter to ``mlx.core``."""
        # ``__getattr__`` only runs after normal lookup has failed. If ``_mx``
        # itself is missing (instance created without running ``__init__``),
        # reading ``self._mx`` below would re-enter this method until the
        # recursion limit is hit; report a plain AttributeError instead so
        # that ``hasattr``/``getattr(..., default)`` keep working.
        if name == "_mx":
            raise AttributeError(name)
        return getattr(self._mx, name)