import functools
import einx._src.tracer as tracer
import einx._src.util.pytree as pytree
from functools import partial
import numpy as np
from einx._src.util.functools import use_name_of

def _to_shapes_inner(axes, shapes_outer):
    shapes_inner = []
    n = None
    for axis, shape_outer in zip(axes, shapes_outer):
        shape_inner = list(shape_outer)
        if axis is not None:
            if axis < 0:
                axis += len(shape_outer)
            if axis < 0 or axis >= len(shape_outer):
                raise ValueError(f"Invalid axis index in vmap")
            if n is None:
                n = shape_inner[axis]
            elif n != shape_inner[axis]:
                raise ValueError(f"Inconsistent axis sizes in vmap")
            del shape_inner[axis]
        shapes_inner.append(shape_inner)
    return shapes_inner, n

def _to_shapes_outer(axes, shapes_inner, n):
    shapes_outer = []
    for axis, shape_inner in zip(axes, shapes_inner):
        shape_outer = list(shape_inner)
        if axis is not None:
            if axis < 0:
                axis += len(shape_inner) + 1
            if axis < 0 or axis >= len(shape_inner) + 1:
                raise ValueError(f"Invalid axis index in vmap")
            shape_outer.insert(axis, n)
        shapes_outer.append(shape_outer)
    return shapes_outer

def vmap(original_vmap):
    """Wrap a backend ``vmap`` tracer so that it can be applied to a Python function ``op``.

    The returned ``vmap_with_shapes(op, in_axes=0, out_axes=0)`` produces a function
    that does the following:

    1. Traces ``op`` on tensors with the per-example (inner) shapes into a ``tracer.Graph``.
    2. Calls ``original_vmap(graph, in_axes=..., out_axes=...)`` on the actual inputs.
    3. Casts the result to ``Tensor`` tracers with the computed outer shapes.

    Args:
        original_vmap: The backend vmap tracer to wrap. The wrapper is decorated with
            ``use_name_of(original_vmap)`` (presumably so it reports the same name).

    Returns:
        The function ``vmap_with_shapes``.
    """
    @use_name_of(original_vmap)
    def vmap_with_shapes(op, in_axes=0, out_axes=0):
        # Normalize scalar axes to 1-tuples so that both cases are handled uniformly.
        if isinstance(in_axes, (int, np.integer)):
            in_axes = (in_axes,)
        if isinstance(out_axes, (int, np.integer)):
            out_axes = (out_axes,)
        if not isinstance(in_axes, tuple):
            raise ValueError(f"Expected in_axes to be a tuple or int, got {type(in_axes)}")
        if not isinstance(out_axes, tuple):
            raise ValueError(f"Expected out_axes to be a tuple or int, got {type(out_axes)}")

        def vmapped_op_with_shapes(*tensors):
            if len(tensors) != len(in_axes):
                raise ValueError(f"Expected {len(in_axes)} arguments in vmapped function, got {len(tensors)}")

            # Get vmapped input shapes
            in_shapes_outer = [t.shape for t in tensors]
            in_shapes_inner, vmapped_axis_len = _to_shapes_inner(in_axes, in_shapes_outer)

            # Create inner graph by tracing op on tensors with the per-example shapes
            in_tracers_inner = [tracer.signature.classical.Tensor(None, shape) for shape in in_shapes_inner]
            out_tracers_inner = op(*in_tracers_inner)
            if not isinstance(out_tracers_inner, tracer.signature.classical.Tensor) and not (isinstance(out_tracers_inner, tuple) and all(isinstance(t, tracer.signature.classical.Tensor) for t in out_tracers_inner)):
                raise ValueError(f"Expected vmapped function to return a tensor or tuple of tensor, got {pytree.map(type, out_tracers_inner)}")
            graph = tracer.Graph(in_tracers_inner, out_tracers_inner)

            # Get vmapped output shapes
            if isinstance(out_tracers_inner, tracer.Tracer):
                out_tracers_inner = [out_tracers_inner]
            if len(out_tracers_inner) != len(out_axes):
                # _to_shapes_outer zips outputs with out_axes. A mismatch would be
                # silently truncated and the result cast to the wrong structure below.
                raise ValueError(f"Expected {len(out_axes)} return values from vmapped function, got {len(out_tracers_inner)}")
            out_shapes_inner = [t.shape for t in out_tracers_inner]
            out_shapes_outer = _to_shapes_outer(out_axes, out_shapes_inner, vmapped_axis_len)

            # Run vmapped function (return value is a simple tracer, since original_vmap is a simple tracer).
            # Single-entry axes are passed as a bare value rather than a 1-tuple.
            out_tensors = original_vmap(
                graph,
                in_axes=in_axes if len(in_axes) > 1 else in_axes[0],
                out_axes=out_axes if len(out_axes) > 1 else out_axes[0],
            )(*tensors)

            # Cast return values to the expected tensor types
            if len(out_axes) == 1:
                return tracer.cast(out_tensors, partial(tracer.signature.classical.Tensor, shape=out_shapes_outer[0]))
            else:
                return tracer.cast(out_tensors, lambda origin: tuple(tracer.signature.classical.Tensor(origin, shape=shape) for shape in out_shapes_outer))

        return vmapped_op_with_shapes
    return vmap_with_shapes