import numpy as np
import einx._src.tracer as tracer
from .._util import _associative_binary_to_nary
from .._util import _axis_to_axisint
from .._util import _axis_to_axistuple
from .._util import _einsum_diag_string
from functools import partial
import einx._src.adapter as adapter

class classical_from_torch:
    """Adapter exposing a NumPy-style ("classical") operation set on top of a torch module.

    The module object to dispatch to is injected through the constructor and stored
    as ``self._torch``; every operation below is forwarded to an attribute of it.
    Argument names follow the NumPy convention (``axis``, ``keepdims``,
    ``indices_or_sections``) and are translated to their torch counterparts
    (``dim``, ``keepdim``, explicit split sizes) before the call.
    """

    def __init__(self, torch):
        """Store the torch module and register the table-driven operations as attributes.

        Args:
            torch: Module object providing the torch API used by this adapter.
        """
        self._torch = torch

        # NOTE(review): `_associative_binary_to_nary` comes from `.._util`; presumably it
        # lifts an associative binary op to accept any number of operands -- confirm there.
        elementwise_name_to_op = {
            "add": _associative_binary_to_nary(self._torch.add),
            "subtract": self._torch.subtract,
            "multiply": _associative_binary_to_nary(self._torch.multiply),
            "true_divide": self._torch.true_divide,
            "floor_divide": self._torch.floor_divide,
            "divide": self._torch.divide,
            "logaddexp": _associative_binary_to_nary(self._torch.logaddexp),
            "logical_and": _associative_binary_to_nary(self._torch.logical_and),
            "logical_or": _associative_binary_to_nary(self._torch.logical_or),
            "logical_xor": _associative_binary_to_nary(self._torch.logical_xor),
            "where": self._torch.where,
            "maximum": _associative_binary_to_nary(self._torch.maximum),
            "minimum": _associative_binary_to_nary(self._torch.minimum),
            "less": self._torch.less,
            "less_equal": self._torch.less_equal,
            "greater": self._torch.greater,
            "greater_equal": self._torch.greater_equal,
            "equal": self._torch.eq,
            "not_equal": self._torch.ne,
            "exp": self._torch.exp,
            "log": self._torch.log,
            "negative": self._torch.neg,
        }
        reduce_name_to_op = {
            "sum": self._torch.sum,
            "mean": self._torch.mean,
            "var": self._torch.var,
            "std": self._torch.std,
            "prod": self._torch.prod,
            "count_nonzero": self._torch.count_nonzero,
            "any": self._torch.any,
            "all": self._torch.all,
            "max": self._torch.amax,
            "min": self._torch.amin,
            "argmax": self._torch.argmax,
            "argmin": self._torch.argmin,
            # When no dim is given, reduce over every axis explicitly.
            "logsumexp": lambda x, dim=None: self._torch.logsumexp(x, dim=dim if dim is not None else tuple(range(x.ndim))),
        }
        update_at_name_to_op = {
            "set_at": partial(self._torch.index_put_, accumulate=False),
            "add_at": partial(self._torch.index_put_, accumulate=True),
            # Subtraction is expressed as accumulating the negated updates.
            "subtract_at": lambda x, indices, updates: self._torch.index_put_(x, indices, self._torch.neg(updates), accumulate=True),
        }

        # Elementwise ops are exposed as-is; reductions and update-at ops are routed
        # through `_reduce` / `_update_at` for argument translation and validation.
        for name, op in elementwise_name_to_op.items():
            setattr(self, name, op)
        for name, op in reduce_name_to_op.items():
            setattr(self, name, partial(self._reduce, op=op))
        for name, op in update_at_name_to_op.items():
            setattr(self, name, partial(self._update_at, op=op))

    def _to_dtype(self, x):
        """Resolve a dtype given by name to the attribute of that name on the torch module.

        Non-string values are returned unchanged.
        """
        if isinstance(x, str):
            return getattr(self._torch, x)
        return x

    def divmod(self, x, y):
        """Return ``(floor_divide(x, y), remainder(x, y))``."""
        quotient = self._torch.floor_divide(x, y)
        remainder = self._torch.remainder(x, y)
        return quotient, remainder

    def sort(self, x, *, axis=None, **kwargs):
        """Forward to ``torch.sort`` along ``axis``; extra kwargs are passed through.

        ``axis=None`` is replaced by all axes of ``x`` before conversion.
        NOTE(review): `_axis_to_axisint` is defined in `.._util`; presumably it
        requires exactly one axis -- confirm there.
        """
        if axis is None:
            axis = tuple(range(x.ndim))
        return self._torch.sort(x, dim=_axis_to_axisint(axis), **kwargs)

    def argsort(self, x, *, axis=None, **kwargs):
        """Forward to ``torch.argsort`` along ``axis``; extra kwargs are passed through.

        ``axis=None`` is replaced by all axes of ``x`` before conversion.
        """
        if axis is None:
            axis = tuple(range(x.ndim))
        return self._torch.argsort(x, dim=_axis_to_axisint(axis), **kwargs)

    def flip(self, x, *, axis=None):
        """Reverse ``x`` along ``axis`` (all axes when ``axis`` is None)."""
        if axis is None:
            axis = tuple(range(x.ndim))
        return self._torch.flip(x, dims=_axis_to_axistuple(axis))

    def roll(self, x, *, shift, axis=None):
        """Roll ``x`` by ``shift`` along ``axis`` (all axes when ``axis`` is None).

        An integer ``shift`` is applied to every selected axis.

        Raises:
            ValueError: If ``shift`` is a list, tuple or ndarray whose length
                differs from the number of axes.
        """
        if axis is None:
            axis = tuple(range(x.ndim))
        axis = _axis_to_axistuple(axis)
        if isinstance(shift, (int, np.integer)):
            shift = (shift,) * len(axis)
        if isinstance(shift, (list, tuple, np.ndarray)) and len(shift) != len(axis):
            raise ValueError(f"Expected the 'shift' argument to have a length of {len(axis)}, but got length {len(shift)}")
        return self._torch.roll(x, shifts=shift, dims=axis)

    def softmax(self, x, *, axis=None):
        """Softmax of ``x`` over ``axis`` (all axes when ``axis`` is None)."""
        return self._softmax(x, name="softmax", axis=axis)

    def log_softmax(self, x, *, axis=None):
        """Log-softmax of ``x`` over ``axis`` (all axes when ``axis`` is None)."""
        return self._softmax(x, name="log_softmax", axis=axis)

    def _stop_gradient(self, x):
        """Return ``x`` detached from the autograd graph.

        The previous implementation assigned to a misspelled attribute
        (``required_grad``), which left gradient tracking untouched. Assigning
        ``requires_grad = False`` would not be a fix either, since torch only
        permits that on leaf tensors; ``detach()`` works for any tensor.
        """
        return x.detach()

    def _softmax(self, x, *, name, axis=None):
        """Shared implementation of ``softmax`` and ``log_softmax``.

        Args:
            x: Input tensor.
            name: ``"softmax"`` or ``"log_softmax"``; selects the function to use.
            axis: Axis or axes to normalise over; None selects all axes.
        """
        if axis is None:
            axis = tuple(range(x.ndim))
        axis = _axis_to_axistuple(axis)
        if len(axis) == 0:
            # Nothing to normalise over: the input is returned unchanged.
            return x
        if len(axis) == 1:
            # Use torch's softmax directly
            return getattr(self._torch.nn.functional, name)(x, dim=axis[0])
        # Use custom implementation for multiple axes
        return getattr(adapter, f"{name}_from_classical")(self, stop_gradient=self._stop_gradient)(x, axis=axis)

    def reshape(self, x, shape):
        """Forward to ``torch.reshape``."""
        return self._torch.reshape(x, shape)

    def transpose(self, x, axes):
        """Permute the axes of ``x`` into the order given by ``axes``."""
        return self._torch.permute(x, tuple(axes))

    def broadcast_to(self, x, shape):
        """Forward to ``torch.broadcast_to`` with ``shape`` converted to a tuple."""
        return self._torch.broadcast_to(x, tuple(shape))

    def diagonal(self, x, *, axes_in, axis_out):
        """Extract a diagonal via ``torch.einsum``.

        The einsum expression is built by `_einsum_diag_string` (from `.._util`)
        from the rank of ``x``, ``axes_in`` and ``axis_out``.
        """
        einsum_str = _einsum_diag_string(x.ndim, axes_in, axis_out)
        return self._torch.einsum(einsum_str, x)

    def _reduce(self, x, *, op, **kwargs):
        """Call a torch reduction with NumPy-style keyword arguments.

        Args:
            x: Input tensor.
            op: Reduction callable, or the name of one on the torch module.
            **kwargs: ``axis`` is renamed to ``dim`` (lists and ndarrays become
                tuples) and ``keepdims`` to ``keepdim``; everything else is
                forwarded unchanged.
        """
        if isinstance(op, str):
            op = getattr(self._torch, op) # Using same reduction names as torch
        if "axis" in kwargs:
            axis = kwargs.pop("axis")
            kwargs["dim"] = tuple(axis) if isinstance(axis, (list, np.ndarray)) else axis
        if "keepdims" in kwargs:
            kwargs["keepdim"] = kwargs.pop("keepdims")
        # A one-element tuple is unwrapped to a plain int before the call.
        if "dim" in kwargs and isinstance(kwargs["dim"], tuple) and len(kwargs["dim"]) == 1:
            kwargs["dim"] = kwargs["dim"][0]
        return op(x, **kwargs)

    def get_at(self, x, indices, *, axis=None):
        """Index into ``x``.

        With ``axis=None``, ``indices`` must be a tuple holding one index array
        per dimension of ``x``, all of the same shape. Otherwise a single axis is
        indexed; only a scalar index, or a 1D ``x`` with ``axis == 0``, is supported.

        Raises:
            ValueError: On a malformed ``indices`` tuple or an out-of-range axis.
            NotImplementedError: For single-axis indexing outside the supported cases.
        """
        # NOTE(review): `getitem` is not a function in the public torch namespace;
        # this relies on the injected module object providing it -- confirm.
        if axis is None:
            # Multidimensional indexing
            if not isinstance(indices, tuple):
                raise ValueError(f"Expected indices to be a tuple, but got {type(indices)}")
            if len(indices) != x.ndim:
                raise ValueError(f"Expected indices to have the same number of elements as x.ndim, but got {len(indices)} and {x.ndim}")
            indices_shapes = {i.shape for i in indices}
            if len(indices_shapes) != 1:
                raise ValueError(f"Expected all indices to have the same shape, but got {indices_shapes}")
            return self._torch.getitem(x, indices)
        else:
            # Singledimensional indexing
            if not hasattr(indices, "shape") or indices.ndim == 0:
                if axis < 0:
                    axis += x.ndim
                if axis < 0 or axis >= x.ndim:
                    raise ValueError(f"Invalid axis {axis} for array of dimension {x.ndim}")
                # Full slices for the leading axes, then the scalar index at `axis`.
                return self._torch.getitem(x, (slice(None),) * axis + (indices,))
            elif x.ndim == 1 and axis == 0:
                return self._torch.take(x, indices)
            else:
                raise NotImplementedError("Don't know how to express this operation in PyTorch")

    def _update_at(self, x, indices, updates, *, op):
        """Validate the arguments and apply an update-at operation to a 1D tensor.

        Raises:
            ValueError: If ``x`` is not 1D, or ``indices`` and ``updates`` differ in rank.
        """
        if x.ndim != 1:
            raise ValueError(f"Expected 1D array, but got {x.ndim}D")
        if indices.ndim != updates.ndim:
            raise ValueError(f"Expected indices and updates to have the same number of dimensions, but got {indices.ndim}D and {updates.ndim}D")
        return op(x, indices, updates)

    def arange(self, n, dtype="int32"):
        """Return ``torch.arange(n)`` with the given dtype (name or dtype object).

        Raises:
            ValueError: If ``n`` is not a Python or NumPy integer.
        """
        if not isinstance(n, (int, np.integer)):
            raise ValueError(f"Expected an integer for n, but got {type(n)}")
        return self._torch.arange(n, dtype=self._to_dtype(dtype))

    def split(self, x, indices_or_sections, *, axis=0):
        """Split ``x`` along ``axis`` following the NumPy ``split`` convention.

        Args:
            x: Input tensor.
            indices_or_sections: An integer is the *number* of equally sized
                sections; a sequence gives the indices at which to cut.
            axis: Axis to split along; negative values count from the end.

        Raises:
            ValueError: If ``axis`` is out of range, the number of sections is not
                positive, or the axis length is not divisible by it.
        """
        if axis < 0:
            axis += x.ndim
        if axis < 0 or axis >= x.ndim:
            raise ValueError(f"Invalid axis {axis} for array of dimension {x.ndim}")
        length = x.shape[axis]

        if isinstance(indices_or_sections, (int, np.integer)):
            # torch.split reads a bare integer as the chunk *size*, so the section
            # count is converted into an explicit list of equal sizes.
            sections = int(indices_or_sections)
            if sections <= 0:
                raise ValueError(f"Expected a positive number of sections, but got {sections}")
            if length % sections != 0:
                raise ValueError(f"Cannot split axis of length {length} into {sections} equal sections")
            sizes = [length // sections] * sections
        else:
            # Convert cut indices into consecutive section sizes.
            boundaries = (0,) + tuple(indices_or_sections) + (length,)
            sizes = [int(s) for s in np.diff(boundaries)]
        return self._torch.split(x, sizes, dim=axis)

    def concatenate(self, xs, *, axis=0):
        """Concatenate the tensors in ``xs`` along ``axis``.

        The axis is normalised and validated against the rank of ``xs[0]``.

        Raises:
            ValueError: If ``axis`` is out of range for ``xs[0]``.
        """
        if axis < 0:
            axis += xs[0].ndim
        if axis < 0 or axis >= xs[0].ndim:
            raise ValueError(f"Invalid axis {axis} for arrays of dimension {xs[0].ndim}")
        return self._torch.cat(xs, dim=axis)

    def dot(self, x, y):
        """Return the dot product of ``x`` and ``y`` after flattening both to 1D.

        Raises:
            ValueError: If ``x`` and ``y`` do not have the same shape.
        """
        if x.shape != y.shape:
            raise ValueError(f"Expected x and y to have the same shape, but got {x.shape} and {y.shape}")
        if x.ndim == 0:
            shape = (1,)
        else:
            # np.prod yields a NumPy scalar; cast so the shape holds a plain int.
            shape = (int(np.prod(x.shape)),)
        x = self.reshape(x, shape)
        y = self.reshape(y, shape)
        return self._torch.dot(x, y)