import pytest
from nesim.utils.tensor_mapping import apply_mapping, find_mapping
import torch
from einops import rearrange

def test_tensor_mapping_dim_0():
    """Recover ``b`` from ``a`` when ``b`` is ``a`` with its last two rows swapped.

    ``find_mapping`` computes the mapping from ``a`` to ``b`` along ``dim=0``;
    applying that mapping to ``a`` with ``apply_mapping`` must reproduce ``b``
    exactly (same shape, same elements).
    """
    a = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
    # Same rows as ``a``; rows 1 and 2 are swapped.
    b = torch.tensor([[0, 1, 2], [6, 7, 8], [3, 4, 5]])

    mapping_from_a_to_b = find_mapping(a, b, dim=0)
    b_recovered = apply_mapping(a, mapping=mapping_from_a_to_b, dim=0)

    # ``==`` broadcasts, so check the shape first; otherwise a wrongly shaped
    # (but broadcast-compatible) result could still compare equal.
    # Assumes apply_mapping returns a torch.Tensor — TODO confirm.
    assert b_recovered.shape == b.shape, (
        f"shape mismatch: expected {tuple(b.shape)}, got {tuple(b_recovered.shape)}"
    )
    assert (b == b_recovered).all(), f"{b, b_recovered}"

def test_tensor_mapping_random_tensor_on_itself():
    """Mapping a tensor onto an identical copy must yield the identity along every dim.

    For each axis of a random ``(6, 5, 4, 3)`` tensor, ``find_mapping`` is
    expected to return ``[0, 1, ..., size - 1]`` for that axis.
    """
    # A local, seeded generator makes failures reproducible without
    # disturbing the global RNG state used by other tests.
    generator = torch.Generator().manual_seed(0)
    a = torch.randn(6, 5, 4, 3, generator=generator)

    for dim in range(a.ndim):
        expected = list(range(a.shape[dim]))
        # Independent copies: if find_mapping modifies its inputs in place,
        # that change cannot leak into ``a`` for later iterations.
        mapping_from_a_to_a = find_mapping(a.clone(), a.clone(), dim=dim)
        assert mapping_from_a_to_a == expected, (
            f"dim: {dim} mapping: {mapping_from_a_to_a} does not match expected value: {expected}"
        )


def test_tensor_mapping_dim_1():
    """Recover ``b`` from ``a`` when ``b`` is ``a`` with its first two columns swapped.

    ``find_mapping`` computes the mapping from ``a`` to ``b`` along ``dim=1``;
    applying that mapping to ``a`` with ``apply_mapping`` must reproduce ``b``
    exactly (same shape, same elements).
    """
    a = torch.tensor([
        [0, 1, 2],
        [3, 4, 5],
        [6, 7, 8],
    ])

    # Same columns as ``a``; columns 0 and 1 are swapped.
    b = torch.tensor([
        [1, 0, 2],
        [4, 3, 5],
        [7, 6, 8],
    ])

    mapping_from_a_to_b = find_mapping(a, b, dim=1)
    b_recovered = apply_mapping(a, mapping=mapping_from_a_to_b, dim=1)

    # ``==`` broadcasts, so check the shape first; otherwise a wrongly shaped
    # (but broadcast-compatible) result could still compare equal.
    # Assumes apply_mapping returns a torch.Tensor — TODO confirm.
    assert b_recovered.shape == b.shape, (
        f"shape mismatch: expected {tuple(b.shape)}, got {tuple(b_recovered.shape)}"
    )
    assert (b == b_recovered).all(), f"{b, b_recovered}"