import einx._src.adapter as adapter
from functools import partial
from ._util import _numpy_from_mlx

class classical_from_mlx:
    """Adapter exposing the operations of ``adapter.classical_from_numpy`` for MLX.

    Most operations are forwarded unchanged from a ``classical_from_numpy``
    instance built over ``_numpy_from_mlx(mlx)``. ``softmax`` and
    ``log_softmax`` are bound to the ``mlx.nn`` implementations through that
    adapter's ``_preserve_shape``, and ``add_at`` / ``subtract_at`` are
    implemented on top of MLX's ``.at[...]`` update interface.
    """

    def __init__(self, mlx):
        """Bind all operations as instance attributes.

        Args:
            mlx: The imported ``mlx`` package. ``mlx.core`` and ``mlx.nn`` are
                accessed as attributes, so both must be available on it.
        """
        self._mx = mlx.core

        self._classical_from_np = adapter.classical_from_numpy(_numpy_from_mlx(mlx))

        # Operations taken over as-is from the numpy-based adapter.
        forwarded_names = (
            adapter.ops.elementwise
            + adapter.ops.reduce
            + adapter.ops.argfind
            + adapter.ops.preserve_shape
            + [
                "reshape",
                "transpose",
                "broadcast_to",
                "arange",
                "split",
                "concatenate",
                "dot",
                "set_at",
                "get_at",
                "divmod",
                "diagonal",
            ]
        )
        for name in forwarded_names:
            setattr(self, name, getattr(self._classical_from_np, name))

        # Assigned after the forwarding loop so that these bindings win over
        # any same-named attribute set above.
        self.softmax = partial(self._classical_from_np._preserve_shape, op=mlx.nn.softmax)
        self.log_softmax = partial(self._classical_from_np._preserve_shape, op=mlx.nn.log_softmax)

        if hasattr(self._mx, "at"):
            at = self._mx.at
        else:
            # No module-level ``at`` available: go through the array's
            # ``.at[indices]`` object and call the method named ``op`` on it.
            def at(x, indices, updates, *, op):
                return getattr(x.at[indices], op)(updates)

        for op_name in ("add", "subtract"):
            setattr(
                self,
                f"{op_name}_at",
                partial(self._update_at, op=partial(at, op=op_name)),
            )

    def _update_at(self, x, indices, updates, *, op):
        """Validate the arguments of an ``*_at`` operation and apply ``op``.

        All three array arguments are passed through the numpy adapter's
        ``_to_tensor`` before validation, so that the ``ndim`` checks below
        work on every input and not only on ``x``.

        Args:
            x: Array to update; must be one-dimensional.
            indices: Positions in ``x`` to update.
            updates: Values to apply; must have the same number of dimensions
                as ``indices``.
            op: Callable invoked as ``op(x, indices, updates)``.

        Returns:
            Whatever ``op(x, indices, updates)`` returns.

        Raises:
            ValueError: If ``x`` is not 1D, or if ``indices`` and ``updates``
                differ in their number of dimensions.
        """
        to_tensor = self._classical_from_np._to_tensor
        x = to_tensor(x)
        indices = to_tensor(indices)
        updates = to_tensor(updates)
        if x.ndim != 1:
            raise ValueError(f"Expected 1D array, but got {x.ndim}D")
        if indices.ndim != updates.ndim:
            raise ValueError(f"Expected indices and updates to have the same number of dimensions, but got {indices.ndim}D and {updates.ndim}D")
        return op(x, indices, updates)
