import einx._src.namedtensor.stage3 as stage3
import einx._src.tracer as tracer
import functools
import einx._src.util.pytree as pytree
from einx._src.namedtensor import NamedTensor
from einx._src.util.functools import use_name_of

def _to_ord_str(i):
    if i == 0:
        return "1st"
    elif i == 1:
        return "2nd"
    elif i == 2:
        return "3rd"
    else:
        return f"{i + 1}th"

def expr_to_ftdims(expr, axisname_to_ftdim, runtime_axisname_to_ftdims=None):
    """Translate a stage3 expression into the functorchdim ``Dim`` objects of its axes.

    A ``List`` becomes a tuple with one entry per child, and an ``Axis`` becomes its ``Dim``.
    ``FlattenedAxis`` and ``Brackets`` nodes are transparent: the result is that of
    their inner expression. Axis names are resolved via ``axisname_to_ftdim`` first;
    names missing there are looked up in ``runtime_axisname_to_ftdims``, which must
    then be provided.

    Raises:
        ValueError: for concatenated axes (not supported by functorchdim) and for any
            unrecognized node type.
    """
    def recurse(node):
        return expr_to_ftdims(node, axisname_to_ftdim, runtime_axisname_to_ftdims)

    if isinstance(expr, stage3.List):
        return tuple(map(recurse, expr.children))

    if isinstance(expr, stage3.Axis):
        name = expr.name
        if name not in axisname_to_ftdim:
            # Fall back to dims that are only known at runtime
            assert runtime_axisname_to_ftdims is not None
            return runtime_axisname_to_ftdims[name]
        return axisname_to_ftdim[name]

    if isinstance(expr, stage3.ConcatenatedAxis):
        raise ValueError("functorchdim does not support axis concatenation.")

    if isinstance(expr, (stage3.FlattenedAxis, stage3.Brackets)):
        return recurse(expr.inner)

    raise ValueError(f"Unexpected expression type: {type(expr)}")

def op(op, functorchdim):
    """Adapt ``op`` so that it operates on einx ``NamedTensor`` inputs through functorchdim.

    The returned function converts every input ``NamedTensor`` into a functorchdim tensor
    (one ``Dim`` per axis name, shared across all inputs). It then calls ``op`` on those
    tensors, checks the return value(s) against the requested output expressions, and
    converts them back into positional ``NamedTensor`` objects.

    Args:
        op: Callable invoked as ``op(*fttensors, axis=axis, **kwargs)``. ``axis`` is a list
            of the ``Dim`` objects of all axes that appear inside brackets in any input
            expression. ``op`` must return a single tracer, or a tuple of tracers with one
            entry per output expression.
        functorchdim: Object providing ``Dim(name, size)`` and the ``Tensor`` type used for
            runtime type checks (presumably the functorchdim module or a tracer standing in
            for it; confirm against callers).

    Returns:
        ``inner(*tensors, out, **kwargs)``. ``out`` is one output expression or a list/tuple
        of them. ``inner`` returns a single ``NamedTensor`` when there is one output, and a
        tuple of ``NamedTensor`` otherwise.

    Raises (from ``inner``):
        ValueError: if ``op`` returns something other than a tracer or a tuple of tracers,
            if it returns a tuple of the wrong length, or if it returns a functorchdim tensor
            whose static shape does not match the output expression. Also raised (via
            ``expr_to_ftdims``) for concatenated axes. Conditions that can only be checked at
            runtime are inserted as assertions into the traced graph instead.
    """
    @use_name_of(op)
    def inner(*tensors, out, **kwargs):
        if not isinstance(out, (list, tuple)):
            out = [out]

        # Convert classical tensors to functorchdim tensors. Every axis name gets exactly one
        # Dim object, so equally named axes of different inputs are identified with each other.
        axisname_to_value = {axis.name: axis.value for tensor in tensors for axis in tensor.expr.nodes() if isinstance(axis, stage3.Axis)}
        axisname_to_ftdim = {name: functorchdim.Dim(name, value) for name, value in axisname_to_value.items()}

        fttensors = []
        for tensor in tensors:
            axisnames = [axis.name for axis in tensor.expr.nodes() if isinstance(axis, stage3.Axis)]
            ftdims = expr_to_ftdims(tensor.expr, axisname_to_ftdim)
            shape = {axisname: axisname_to_value[axisname] for axisname in axisnames}

            # Index the positional tensor with its Dims (functorchdim binds positions to first-class dims this way)
            fttensor = tracer.signature.python.getitem(tensor.value, ftdims)
            fttensor = tracer.cast(fttensor, lambda origin: tracer.signature.functorchdim.Tensor(origin, shape=shape))
            fttensors.append(fttensor)

        # Call the operation with functorchdim tensors and check/cast return value back to classical tensors.
        # The Dims of all bracketed axes are passed as ``axis``. They are collected in first-seen order:
        # iterating a set of strings would make the order (and thus the traced graph) depend on
        # Python's per-process hash randomization.
        bracketed_axisnames = dict.fromkeys(
            axis.name
            for tensor in tensors
            for axis in tensor.expr.nodes()
            if isinstance(axis, stage3.Axis) and stage3.is_in_brackets(axis)
        )
        axis = [axisname_to_ftdim[name] for name in bracketed_axisnames]
        fttensors = op(*fttensors, axis=axis, **kwargs)
        if isinstance(fttensors, tracer.Tracer):
            # Create list of tracers
            if len(out) == 1:
                fttensors = [fttensors]
            else:
                # A single tracer standing for several outputs must be a tuple of the right length at runtime
                fttensors = tracer.signature.python.assert_(
                    fttensors,
                    tracer.signature.python.builtins.isinstance(
                        fttensors,
                        tracer.signature.python.builtins.tuple,
                    ),
                    f"Expected the adapted function to return a tuple of length {len(out)}",
                )
                fttensors = tracer.signature.python.assert_(
                    fttensors,
                    tracer.signature.python.equal(
                        tracer.signature.python.builtins.len(fttensors),
                        len(out),
                    ),
                    f"Expected the adapted function to return a tuple of length {len(out)}",
                )
                fttensors = tracer.cast(fttensors, lambda origin: [tracer.signature.python.Value(origin) for _ in range(len(out))])
        elif isinstance(fttensors, tuple) and all(isinstance(t, tracer.Tracer) for t in fttensors):
            # Return value already is a tuple of tracers
            if len(fttensors) != len(out):
                raise ValueError(f"Expected the adapted function to return a tuple of length {len(out)}, but got length {len(fttensors)}")
        else:
            raise ValueError(f"Expected the adapted function to return a tracer or a tuple of tracers, but got {pytree.map(type, fttensors)}")

        tensors = []
        for i, (fttensor, expr) in enumerate(zip(fttensors, out)):
            axisnames = [axis.name for axis in expr.nodes() if isinstance(axis, stage3.Axis)]
            expected_shape = {axisname: axisname_to_value[axisname] for axisname in axisnames}

            if isinstance(fttensor, tracer.signature.functorchdim.Tensor):
                # Return type is a tensor -> ensure that the static shape is correct.
                # Assumes ``fttensor.shape`` is an axis-name -> size mapping like the ``shape``
                # dicts passed to ``functorchdim.Tensor`` in this function; compared
                # order-insensitively since first-class dims are unordered.
                if dict(fttensor.shape) != expected_shape:
                    raise ValueError(f"Expected {_to_ord_str(i)} return value of the adapted function to be a tensor with shape {expected_shape}, but got shape {fttensor.shape}")
            else:
                # Return type is a general tracer object -> ensure that it has the correct type and that the shape is correct at runtime. Then cast to expected shape
                fttensor = tracer.signature.python.assert_(
                    fttensor,
                    tracer.signature.python.builtins.isinstance(fttensor, functorchdim.Tensor),
                    f"Expected {_to_ord_str(i)} return value of the adapted function to be a tensor",
                )
                runtime_shape = tracer.signature.python.builtins.dict(tracer.signature.python.builtins.map(
                    tracer.signature.python.function(lambda dim: (tracer.signature.python.builtins.repr(dim), dim.size)),
                    fttensor.dims,
                ))
                fttensor = tracer.signature.python.assert_(
                    fttensor,
                    tracer.signature.python.equal(runtime_shape, expected_shape),
                    f"Expected {_to_ord_str(i)} return value of the adapted function to be a tensor with shape {expected_shape}",
                )
                fttensor = tracer.cast(fttensor, lambda origin: tracer.signature.functorchdim.Tensor(origin, shape=expected_shape))

            # Runtime lookup of the returned tensor's dims by name (same ``repr(dim)`` key as runtime_shape above)
            runtime_axisname_to_ftdims = tracer.signature.python.builtins.dict(tracer.signature.python.builtins.map(
                tracer.signature.python.function(lambda dim: (tracer.signature.python.builtins.repr(dim), dim)),
                fttensor.dims,
            ))
            ftdims = expr_to_ftdims(expr, axisname_to_ftdim, runtime_axisname_to_ftdims) # Prefer statically available ftdims

            # Reorder the first-class dims back into positional dimensions of the output expression
            tensor = fttensor.order(*ftdims)
            tensor = tracer.cast(tensor, lambda origin: tracer.signature.classical.Tensor(origin, shape=expr.shape))
            tensors.append(NamedTensor(tensor, expr))

        if len(tensors) == 1:
            return tensors[0]
        else:
            return tuple(tensors)

    return inner

def elementwise(op, functorchdim):
    """Adapt an elementwise ``op`` via the module-level ``op`` adapter.

    The wrapped function raises ``ValueError`` when it receives a non-empty ``axis``.
    Otherwise it forwards the tensors and remaining keyword arguments to ``op``.
    """
    # The parameter ``op`` shadows the module-level adapter of the same name
    adapt = globals()["op"]

    @use_name_of(op)
    def inner(*tensors, axis, **kwargs):
        if len(axis) > 0:
            raise ValueError("Elementwise operations do not support axis arguments.")
        return op(*tensors, **kwargs)

    return adapt(inner, functorchdim)

def reduce(op, functorchdim):
    """Adapt a reduction ``op`` via the module-level ``op`` adapter.

    The wrapped function calls ``op(tensor, axis, **kwargs)``, passing ``axis``
    positionally as the second argument.
    """
    # The parameter ``op`` shadows the module-level adapter of the same name
    adapt = globals()["op"]

    @use_name_of(op)
    def inner(tensor, axis, **kwargs):
        positional = (tensor, axis)
        return op(*positional, **kwargs)

    return adapt(inner, functorchdim)

# def get_at(functorchdim):
#     def get_at(tensor, *coords, axis):
#         coords2 = []
#         for coord in coords:
#             dim_ids = [id(dim) for dim in coord.dims]
#             coord_axis = [axis for axis in axis if id(axis) in dim_ids]
#             if len(coord_axis) == 0:
#                 coords2.append(coord)
#             elif len(coord_axis) == 1:
#                 coord = coord.order(coord_axis[0])
#                 for i in range(coord.shape[0]):
#                     coords2.append(coord[i])
#             else:
#                 raise ValueError(f"At most one marked axis can be specified for a coordinate tensor")
#         coords = coords2

#         dim_ids = [id(dim) for dim in tensor.dims]
#         tensor_axis = [axis for axis in axis if id(axis) in dim_ids]
#         if len(tensor_axis) != len(coords):
#             raise ValueError(f"Expected the number of coordinates to match the number of marked axes in the tensor, but got {len(coords)} coordinates and {len(tensor_axis)} axes")

