from ._util import _axis_to_axistuple

class logsumexp_from_classical:
    """Numerically stable ``logsumexp`` assembled from primitive ("classical") ops.

    Computes ``log(sum(exp(x), axis))`` as ``m + log(sum(exp(x - m), axis))``
    with ``m = max(x, axis)``, so large inputs do not overflow inside ``exp``.

    Args:
        classical: Backend object providing ``max``, ``subtract``, ``exp``,
            ``sum``, ``log``, ``add`` and ``reshape``; the reductions are called
            with NumPy-style ``axis``/``keepdims`` keyword arguments.
        stop_gradient: Optional callable applied to the shift ``m``; pass
            ``None`` to use the shift as-is.
    """

    def __init__(self, classical, stop_gradient):
        self.classical = classical
        self.stop_gradient = stop_gradient

    def __call__(self, x, axis=None, keepdims=False):
        """Return ``log(sum(exp(x)))`` reduced over ``axis``.

        Args:
            x: Input array; must expose ``.ndim``.
            axis: Axis or axes to reduce over; ``None`` reduces over all axes.
                Negative axes are counted from the end.
            keepdims: If true, reduced axes are kept with size 1.

        Returns:
            The reduced array produced by the ``classical`` backend.
        """
        if axis is None:
            axes = tuple(range(x.ndim))
        else:
            # Normalize negative axes so they match the non-negative indices
            # used below when dropping the reduced dimensions of the shift.
            axes = tuple(a + x.ndim if a < 0 else a for a in _axis_to_axistuple(axis))

        x_max = self.classical.max(x, axis=axis, keepdims=True)
        if self.stop_gradient is not None:
            x_max = self.stop_gradient(x_max)

        shifted = self.classical.subtract(x, x_max)
        result = self.classical.log(self.classical.sum(self.classical.exp(shifted), axis=axis, keepdims=keepdims))

        if not keepdims:
            # ``result`` has the reduced axes removed; drop the matching size-1
            # axes of the shift so the final add does not broadcast.
            x_max = self.classical.reshape(
                x_max,
                tuple(s for i, s in enumerate(x_max.shape) if i not in axes),
            )

        return self.classical.add(result, x_max)

class softmax_from_classical:
    """Numerically stable ``softmax`` assembled from primitive ("classical") ops.

    Computes ``exp(x - m) / sum(exp(x - m), axis)`` with ``m = max(x, axis)``,
    so large inputs do not overflow inside ``exp``.

    Args:
        classical: Backend object providing ``max``, ``subtract``, ``exp``,
            ``sum`` and ``divide``; the reductions are called with NumPy-style
            ``axis``/``keepdims`` keyword arguments.
        stop_gradient: Optional callable applied to the shift ``m``; pass
            ``None`` to use the shift as-is.
    """

    def __init__(self, classical, stop_gradient):
        self.classical = classical
        self.stop_gradient = stop_gradient

    def __call__(self, x, axis=None):
        """Return the softmax of ``x`` normalized over ``axis`` (all axes if ``None``)."""
        x_max = self.classical.max(x, axis=axis, keepdims=True)
        if self.stop_gradient is not None:
            x_max = self.stop_gradient(x_max)

        # Evaluate the exponentials once and reuse them for both the numerator
        # and the normalizing sum.
        exp_x = self.classical.exp(self.classical.subtract(x, x_max))
        return self.classical.divide(exp_x, self.classical.sum(exp_x, axis=axis, keepdims=True))

class log_softmax_from_classical:
    """``log_softmax`` expressed as ``x - logsumexp(x, axis, keepdims=True)``.

    Args:
        classical: Backend object of primitive ops; used here for ``subtract``
            and forwarded to the internal ``logsumexp_from_classical``.
        stop_gradient: Forwarded to the internal logsumexp, which applies it to
            its max shift when not ``None``.
    """

    def __init__(self, classical, stop_gradient):
        self.classical = classical
        self.logsumexp = logsumexp_from_classical(classical, stop_gradient)

    def __call__(self, x, axis=None):
        """Return the log-softmax of ``x`` over ``axis`` (all axes if ``None``)."""
        # keepdims=True keeps the normalizer broadcastable against ``x``.
        log_normalizer = self.logsumexp(x, axis=axis, keepdims=True)
        return self.classical.subtract(x, log_normalizer)