from collections import defaultdict, deque
from typing import List

import torch
from torchtyping import TensorType

def bring_dimension_to_first_axis(tensor: torch.Tensor, desired_dimension: int) -> torch.Tensor:
    """
    Bring the desired dimension to the first axis of the tensor.

    The other dimensions keep their relative order. The result is a view of
    `tensor` (no data is copied), so writes into it are visible in `tensor`.

    Args:
        tensor: input tensor
        desired_dimension: dimension to move to axis 0; negative values count
            from the last dimension, as elsewhere in PyTorch

    Returns:
        torch.Tensor: view of `tensor` with `desired_dimension` moved to axis 0

    Raises:
        IndexError: if `desired_dimension` is out of range for `tensor`
    """
    # torch.movedim wraps negative dims and preserves the order of the remaining
    # axes; a permutation built with range() breaks for negative dims because
    # range(negative) is empty and the moved axis ends up listed twice.
    return torch.movedim(tensor, desired_dimension, 0)

def send_first_dimension_to_desired(tensor: torch.Tensor, desired_dimension: int) -> torch.Tensor:
    """
    Send the first dimension of the tensor back to the desired dimension.

    This is the inverse of `bring_dimension_to_first_axis` for the same
    `desired_dimension`. The other dimensions keep their relative order, and the
    result is a view of `tensor` (no data is copied).

    Args:
        tensor: input tensor
        desired_dimension: position that axis 0 should end up at; negative
            values count from the last dimension, as elsewhere in PyTorch

    Returns:
        torch.Tensor: view of `tensor` with axis 0 moved to `desired_dimension`

    Raises:
        IndexError: if `desired_dimension` is out of range for `tensor`
    """
    # torch.movedim wraps negative destinations; a range()-built permutation
    # would list axis 0 twice for negative values and fail.
    return torch.movedim(tensor, 0, desired_dimension)

@torch.no_grad()
def find_mapping(x: TensorType, y: TensorType, dim: int):
    """
    Find the index mapping that reorders `y` along `dim` so that it equals `x`.

    Compatible with the `apply_mapping` function. `mapping[i]` is the index of
    the slice of `y` (taken along `dim`) that equals the i-th slice of `x`, so
    indexing `y` with `mapping` along `dim` reproduces `x`. Each index of `y` is
    used exactly once. When several slices are identical, they are matched in
    increasing index order.

    Args:
        - `x` (`TensorType`): a tensor
        - `y` (`TensorType`): another tensor of the same shape, containing the
          same slices as `x` along dimension `dim` but possibly in a different order
        - `dim` (`int`): dimension along which we have to map x and y; negative
          values count from the last dimension, as elsewhere in PyTorch

    Raises:
        ValueError: if x and y don't have the same shape, or if the slices of
            `x` along `dim` are not a rearrangement of the slices of `y`
        IndexError: if `dim` is out of range for the tensors

    Returns:
        list: mapping along dimension `dim`, one index into `y` per slice of `x`
    """
    # Check if the shapes of x and y are the same
    if x.shape != y.shape:
        raise ValueError("x and y must have the same shape")

    ndim = x.dim()
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} is out of range for tensors with {ndim} dimensions")
    if dim < 0:
        # Normalize so a negative dim is moved to the front like a positive one,
        # instead of silently being treated as dimension 0.
        dim += ndim

    if dim > 0:
        x = bring_dimension_to_first_axis(x, desired_dimension=dim)
        y = bring_dimension_to_first_axis(y, desired_dimension=dim)

    # Bucket y's slice indices by slice contents. Slices are compared by value
    # through their flattened Python contents, so NaN never matches, just as with
    # elementwise `==`. Each bucket is consumed front-to-back, so every index of
    # y is used at most once. This avoids rescanning y for every slice of x.
    available = defaultdict(deque)
    for j in range(y.shape[0]):
        available[tuple(y[j].reshape(-1).tolist())].append(j)

    mapping = []
    for i in range(x.shape[0]):
        candidates = available.get(tuple(x[i].reshape(-1).tolist()))
        if not candidates:
            # Raise instead of assert: validation must survive `python -O`.
            raise ValueError("Please make sure that x and y have the same set of elements :(")
        mapping.append(candidates.popleft())

    return mapping


@torch.no_grad()
def apply_mapping(y: TensorType, mapping: List[int], dim: int):
    """Rearrange the given tensor with the provided mapping along dim `dim`.

    After the call, slice `i` of `y` along `dim` holds what slice `mapping[i]`
    held before, i.e. the result equals `y` indexed with `mapping` along `dim`.
    With a mapping produced by `find_mapping(x, y, dim)`, the result equals `x`.

    The data of `y` is overwritten in place, and the returned tensor shares
    storage with `y`, so the caller's tensor is rearranged as well.

    Args:
        - `y` (`TensorType`): tensor to be rearranged (modified in place)
        - `mapping` (`List[int]`): list of indices, use the find_mapping function provided above
        - `dim` (`int`): dimension along which we have to apply the mapping;
          negative values count from the last dimension, as elsewhere in PyTorch

    Raises:
        IndexError: if `dim` is out of range for `y`

    Returns:
        TensorType: the rearranged tensor, with the same shape as `y`
    """
    ndim = y.dim()
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} is out of range for a tensor with {ndim} dimensions")
    if dim < 0:
        # Normalize so a negative dim is moved to the front like a positive one,
        # instead of silently being treated as dimension 0.
        dim += ndim

    if dim > 0:
        # The helper returns a view, so the assignment below writes into y's storage.
        y = bring_dimension_to_first_axis(y, desired_dimension=dim)

    # `y[mapping]` (advanced indexing) materializes a copy before the write,
    # so reading from and writing to the same storage here is safe.
    y[:] = y[mapping]

    if dim > 0:
        y = send_first_dimension_to_desired(y, desired_dimension=dim)

    return y