import einx._src.adapter as adapter
from functools import partial

class classical_from_jax:
    """Classical-operation adapter backed by a JAX module.

    Most operations are forwarded to an ``adapter.classical_from_numpy``
    instance built around ``jax.numpy``. In addition, the constructor wires
    the ``jax.nn`` implementations of ``logsumexp``, ``softmax`` and
    ``log_softmax`` into the wrapped adapter's generic helpers, and installs
    one update operation per name listed in ``adapter.ops.update_at``.
    """

    def __init__(self, jax):
        """Bind the adapter to the given ``jax`` module.

        Args:
            jax: The ``jax`` module object; ``jax.numpy`` and ``jax.nn`` are
                read from it.
        """
        self._jax = jax
        self._jnp = jax.numpy

        np_adapter = adapter.classical_from_numpy(self._jnp)
        self._classical_from_np = np_adapter

        # Re-export the generic operations of the wrapped adapter unchanged.
        forwarded_names = (
            adapter.ops.elementwise
            + adapter.ops.reduce
            + adapter.ops.argfind
            + adapter.ops.preserve_shape
        )
        for op_name in forwarded_names:
            setattr(self, op_name, getattr(np_adapter, op_name))

        # Assigned after the loop above, so these win over any same-named
        # attribute that the loop may have installed.
        self.logsumexp = partial(np_adapter._reduce, op=jax.nn.logsumexp)
        self.softmax = partial(np_adapter._preserve_shape, op=jax.nn.softmax)
        self.log_softmax = partial(np_adapter._preserve_shape, op=jax.nn.log_softmax)

        def _at_via_accessor(x, indices, updates, *, op):
            # Call the method named by ``op`` on ``x.at[indices]``.
            accessor = x.at[indices]
            return getattr(accessor, op)(updates)

        # Use ``jax.at`` when the given module exposes it, else the fallback.
        at = jax.at if hasattr(jax, "at") else _at_via_accessor
        op_for_name = {
            "set_at": partial(at, op="set"),
            "add_at": partial(at, op="add"),
            "subtract_at": partial(at, op="subtract"),
        }
        for op_name in adapter.ops.update_at:
            setattr(self, op_name, partial(self._update_at, op=op_for_name[op_name]))

    def divmod(self, x, y):
        """Forward to the wrapped adapter's ``divmod``."""
        delegate = self._classical_from_np
        return delegate.divmod(x, y)

    def reshape(self, x, shape):
        """Forward to the wrapped adapter's ``reshape``."""
        delegate = self._classical_from_np
        return delegate.reshape(x, shape)

    def transpose(self, x, axes):
        """Forward to the wrapped adapter's ``transpose``."""
        delegate = self._classical_from_np
        return delegate.transpose(x, axes)

    def broadcast_to(self, x, shape):
        """Forward to the wrapped adapter's ``broadcast_to``."""
        delegate = self._classical_from_np
        return delegate.broadcast_to(x, shape)

    def diagonal(self, x, axes_in, axis_out):
        """Forward to the wrapped adapter's ``diagonal``.

        ``axes_in`` and ``axis_out`` are passed on as keyword arguments.
        """
        delegate = self._classical_from_np
        return delegate.diagonal(x, axes_in=axes_in, axis_out=axis_out)

    def get_at(self, x, indices, **kwargs):
        """Forward to the wrapped adapter's ``get_at``, including any keyword arguments."""
        delegate = self._classical_from_np
        return delegate.get_at(x, indices, **kwargs)

    def _update_at(self, x, indices, updates, *, op):
        """Validate the arguments and apply ``op(x, indices, updates)``.

        Only ``x`` is converted with the wrapped adapter's ``_to_tensor``;
        ``indices`` and ``updates`` are used as given and must expose ``ndim``.

        Args:
            x: Value to update; must be 1-dimensional after conversion.
            indices: Index argument handed to ``op``.
            updates: Update argument handed to ``op``.
            op: Callable invoked as ``op(x, indices, updates)``.

        Returns:
            Whatever ``op`` returns.

        Raises:
            ValueError: If the converted ``x`` is not 1-dimensional, or if
                ``indices.ndim`` differs from ``updates.ndim``.
        """
        tensor = self._classical_from_np._to_tensor(x)
        rank = tensor.ndim
        if rank != 1:
            raise ValueError(f"Expected 1D array, but got {rank}D")

        indices_rank = indices.ndim
        updates_rank = updates.ndim
        if indices_rank != updates_rank:
            raise ValueError(
                "Expected indices and updates to have the same number of dimensions, "
                f"but got {indices_rank}D and {updates_rank}D"
            )

        return op(tensor, indices, updates)

    def arange(self, n, dtype="int32"):
        """Forward to the wrapped adapter's ``arange``; ``dtype`` defaults to ``"int32"``."""
        delegate = self._classical_from_np
        return delegate.arange(n, dtype=dtype)

    def split(self, x, indices_or_sections, axis=0):
        """Forward to the wrapped adapter's ``split``; ``axis`` is passed by keyword."""
        delegate = self._classical_from_np
        return delegate.split(x, indices_or_sections, axis=axis)

    def concatenate(self, xs, axis=0):
        """Forward to the wrapped adapter's ``concatenate``; ``axis`` is passed by keyword."""
        delegate = self._classical_from_np
        return delegate.concatenate(xs, axis=axis)

    def dot(self, x, y):
        """Forward to the wrapped adapter's ``dot``."""
        delegate = self._classical_from_np
        return delegate.dot(x, y)
