import torch
from einops import rearrange

# def horizontal_forward_scan(input_tensor):
#     """
#     
#     Legacy channels-first variant of the row-major scan.
#     Input: (B, C, H, W)
#     Output: (B, C, H * W)
#     """
#     B, C, H, W = input_tensor.shape
#     input_tensor = input_tensor.permute(0, 2, 3, 1)  # (B, H, W, C)
#     flattened = input_tensor.reshape(B, H * W, C)  # (B, H * W, C)
#     return flattened.permute(0, 2, 1)  # (B, C, H * W)

# def horizontal_forward_scan_inv(transformed_tensor, original_shape):
#     """
#     Inverse of horizontal_forward_scan.
#     Input: (B, C, H * W)
#     Output: (B, C, H, W)
#     """
#     B, C, H, W = original_shape
#     transformed_tensor = transformed_tensor.permute(0, 2, 1)  # (B, H * W, C)
#     recovered = transformed_tensor.view(B, H, W, C)  #  (B, H, W, C)
#     return recovered.permute(0, 3, 1, 2)  # (B, C, H, W)


# def horizontal_backward_scan(input_tensor):
#     """
#     
#     """
#     B, C, H, W = input_tensor.shape
#     input_tensor = torch.flip(input_tensor, dims=[-1])  #  (B, C, H, W)
#     input_tensor = input_tensor.permute(0, 2, 3, 1)  # (B, H, W, C)
#     flattened = input_tensor.reshape(B, H * W, C)  # (B, H * W, C)
#     return flattened.permute(0, 2, 1)  # (B, C, H * W)

# def horizontal_backward_scan_inv(transformed_tensor, original_shape):
#     """
#     Inverse of horizontal_backward_scan.
#     """
#     B, C, H, W = original_shape
#     transformed_tensor = transformed_tensor.permute(0, 2, 1)  # (B, H * W, C)
#     recovered = transformed_tensor.view(B, H, W, C)  # (B, H, W, C)
#     recovered = recovered.permute(0, 3, 1, 2)  # (B, C, H, W)
#     return torch.flip(recovered, dims=[-1])  # 

def horizontal_forward_scan(input_tensor):
    """Flatten a feature map into a row-major token sequence.

    Rows are visited in order and, within a row, columns in order, so the
    token at index ``h * W + w`` holds the channel vector found at spatial
    position ``(h, w)``.

    Args:
        input_tensor: Tensor of shape (B, C, H, W).

    Returns:
        Tensor of shape (B, H * W, C).
    """
    batch, channels, height, width = input_tensor.shape
    # Channels go last first, so that collapsing (H, W) yields one token per
    # spatial position with its channel vector intact.
    return input_tensor.permute(0, 2, 3, 1).reshape(
        batch, height * width, channels
    )

def horizontal_forward_scan_inv(transformed_tensor, original_shape):
    """Undo ``horizontal_forward_scan``.

    Args:
        transformed_tensor: Tensor of shape (B, H * W, C).
        original_shape: The (B, C, H, W) shape of the tensor before scanning.

    Returns:
        Tensor of shape (B, C, H, W).
    """
    batch, channels, height, width = original_shape
    # Split the token axis back into (H, W), then restore channels-first.
    grid = transformed_tensor.view(batch, height, width, channels)
    return grid.permute(0, 3, 1, 2)


def horizontal_backward_scan(input_tensor):
    """Flatten a feature map row by row, reading each row in reverse.

    The width axis is reversed before the row-major flatten, so the token at
    index ``h * W + w`` holds the channel vector found at spatial position
    ``(h, W - 1 - w)``.

    Args:
        input_tensor: Tensor of shape (B, C, H, W).

    Returns:
        Tensor of shape (B, H * W, C).
    """
    batch, channels, height, width = input_tensor.shape
    mirrored = input_tensor.flip(dims=[-1])  # reverse the W axis
    channels_last = mirrored.permute(0, 2, 3, 1)  # (B, H, W, C)
    return channels_last.reshape(batch, height * width, channels)

def horizontal_backward_scan_inv(transformed_tensor, original_shape):
    """Undo ``horizontal_backward_scan``.

    Args:
        transformed_tensor: Tensor of shape (B, H * W, C).
        original_shape: The (B, C, H, W) shape of the tensor before scanning.

    Returns:
        Tensor of shape (B, C, H, W).
    """
    batch, channels, height, width = original_shape
    grid = transformed_tensor.view(batch, height, width, channels)
    channels_first = grid.permute(0, 3, 1, 2)  # (B, C, H, W)
    # The forward scan reversed the W axis; reversing it again restores it.
    return channels_first.flip(dims=[-1])


def vertical_forward_scan(input_tensor):
    """Flatten a feature map into a column-major token sequence.

    Columns are visited in order and, within a column, rows in order, so the
    token at index ``w * H + h`` holds the channel vector found at spatial
    position ``(h, w)``.

    Args:
        input_tensor: Tensor of shape (B, C, H, W).

    Returns:
        Tensor of shape (B, W * H, C).
    """
    # The unpack is only here to reject inputs that are not 4-D.
    batch, channels, height, width = input_tensor.shape
    columns_first = input_tensor.transpose(-1, -2)  # (B, C, W, H)
    channel_sequences = columns_first.flatten(start_dim=2)  # (B, C, W*H)
    return channel_sequences.transpose(1, 2)  # (B, W*H, C)

def vertical_forward_scan_inv(transformed_tensor, original_shape):
    """Undo ``vertical_forward_scan``.

    Args:
        transformed_tensor: Tensor of shape (B, W * H, C).
        original_shape: The (B, C, H, W) shape of the tensor before scanning.

    Returns:
        Tensor of shape (B, C, H, W).
    """
    batch, channels, height, width = original_shape
    channel_sequences = transformed_tensor.transpose(1, 2)  # (B, C, W*H)
    # Tokens were laid out column by column, so the split is (W, H), not (H, W).
    columns_first = channel_sequences.view(batch, channels, width, height)
    return columns_first.transpose(-1, -2)  # (B, C, H, W)


def vertical_backward_scan(input_tensor):
    """Flatten a feature map column by column, reading each column in reverse.

    The height axis is reversed before the column-major flatten, so the token
    at index ``w * H + h`` holds the channel vector found at spatial position
    ``(H - 1 - h, w)``.

    Args:
        input_tensor: Tensor of shape (B, C, H, W).

    Returns:
        Tensor of shape (B, W * H, C).
    """
    # The unpack is only here to reject inputs that are not 4-D.
    batch, channels, height, width = input_tensor.shape
    upside_down = input_tensor.flip(dims=[-2]).contiguous()  # reverse the H axis
    columns_first = upside_down.transpose(-1, -2)  # (B, C, W, H)
    channel_sequences = columns_first.flatten(start_dim=2)  # (B, C, W*H)
    return channel_sequences.transpose(1, 2)  # (B, W*H, C)

def vertical_backward_scan_inv(transformed_tensor, original_shape):
    """Undo ``vertical_backward_scan``.

    Args:
        transformed_tensor: Tensor of shape (B, W * H, C).
        original_shape: The (B, C, H, W) shape of the tensor before scanning.

    Returns:
        Tensor of shape (B, C, H, W).
    """
    batch, channels, height, width = original_shape
    channel_sequences = transformed_tensor.transpose(1, 2)  # (B, C, W*H)
    # Tokens were laid out column by column, so the split is (W, H), not (H, W).
    columns_first = channel_sequences.view(batch, channels, width, height)
    upside_down = columns_first.transpose(-1, -2)  # (B, C, H, W)
    # The forward scan reversed the H axis; reversing it again restores it.
    return upside_down.flip(dims=[-2])

if __name__ == "__main__":
    # Round-trip self-check: applying a scan and then its inverse must give
    # back the original tensor.
    input_tensor = torch.tensor([[[[1, 2, 3],
                                   [4, 5, 6],
                                   [7, 8, 9]]]], dtype=torch.float32)  # (1,1,3,3)

    # Show the row-major token order for the small sample.
    print(horizontal_forward_scan(input_tensor))

    scan_pairs = [
        (horizontal_forward_scan, horizontal_forward_scan_inv),
        (horizontal_backward_scan, horizontal_backward_scan_inv),
        (vertical_forward_scan, vertical_forward_scan_inv),
        (vertical_backward_scan, vertical_backward_scan_inv),
    ]

    # Also check a non-square, multi-channel, multi-batch tensor: with
    # H == W and C == 1 an axis mix-up in a scan/inverse pair could go
    # unnoticed.
    samples = [
        input_tensor,
        torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5),
    ]

    for scan, scan_inv in scan_pairs:
        for sample in samples:
            recovered = scan_inv(scan(sample), sample.shape)
            assert torch.allclose(sample, recovered), f"{scan_inv.__name__} failed!"

    print("！")