from itertools import product

import torch
from torchvision.transforms.functional import affine


def calc_shifts_pixel_based(image, background_value: float):
    """Build a 3x3 grid of ``(dx, dy)`` offsets from the extent of the non-background content.

    The content extremes are found with ``calc_top_left_right_down_first_pixel``.
    The horizontal candidates are ``-(leftmost_col + 1)``, ``0`` and
    ``width - rightmost_col``. The vertical candidates are ``-(topmost_row + 1)``,
    ``0`` and ``height - bottommost_row``. ``width`` is ``image.shape[2]`` and
    ``height`` is ``image.shape[1]``.

    NOTE(review): applied as pixel translations, these offsets appear to push
    the outermost content row/column just past each border; confirm the
    intended semantics.

    Args:
        image: indexed as ``image[0, row, col]``; presumably CHW.
        background_value: value treated as background.

    Returns:
        A list of 9 ``(dx, dy)`` tuples. All three ``dy`` values come first
        for the first ``dx``, then for the second, then for the third.

    Raises:
        TypeError: if the image has no non-background pixel, because a
            finder then returns ``None`` and it gets subscripted.
    """
    top, left, right, bottom = calc_top_left_right_down_first_pixel(image, background_value)
    x_offsets = [-left[1] - 1, 0, image.shape[2] - right[1]]
    y_offsets = [-top[0] - 1, 0, image.shape[1] - bottom[0]]
    return [(dx, dy) for dx in x_offsets for dy in y_offsets]


def calc_top_left_right_down_first_pixel(image, background_value: float):
    """Run the four directional first-pixel finders and collect their results.

    Args:
        image: indexed as ``image[0, row, col]``; presumably CHW.
        background_value: value treated as background.

    Returns:
        A list ``[top, left, right, bottom]`` of ``(row, col)`` tuples. Any
        entry may be ``None`` when its finder sees no non-background pixel.
    """
    finders = (
        calc_first_top_not_background_pixel,
        calc_first_left_not_background_pixel,
        calc_first_right_not_background_pixel,
        calc_first_bottom_not_background_pixel,
    )
    return [find(image, background_value) for find in finders]


def calc_first_top_not_background_pixel(image, background_value: float):
    """Find the first non-background pixel when scanning from the top.

    Channel 0 is scanned row by row from the top. Within each row the scan
    goes left to right.

    Args:
        image: indexed as ``image[0, row, col]``; presumably CHW.
        background_value: value treated as background.

    Returns:
        ``(row, col)`` of the first matching pixel, or ``None`` if every
        pixel in channel 0 equals ``background_value``.
    """
    # Generator: the inner range is re-evaluated lazily for each row.
    hits = (
        (row, col)
        for row in range(image.shape[1])
        for col in range(image.shape[2])
        if image[0, row, col] != background_value
    )
    return next(hits, None)


def calc_first_left_not_background_pixel(image, background_value: float):
    """Find the first non-background pixel when scanning from the left.

    Channel 0 is scanned column by column from the left. Within each column
    the scan goes top to bottom.

    Args:
        image: indexed as ``image[0, row, col]``; presumably CHW.
        background_value: value treated as background.

    Returns:
        ``(row, col)`` of the first matching pixel, or ``None`` if every
        pixel in channel 0 equals ``background_value``.
    """
    # Column-major scan: the outer generator loop walks columns, the inner loop walks rows.
    hits = (
        (row, col)
        for col in range(image.shape[2])
        for row in range(image.shape[1])
        if image[0, row, col] != background_value
    )
    return next(hits, None)


def calc_first_right_not_background_pixel(image, background_value: float):
    """Find the first non-background pixel when scanning from the right.

    Channel 0 is scanned column by column from the right. Within each column
    the scan goes bottom to top. Every row and column is visited, including
    index 0.

    Args:
        image: indexed as ``image[0, row, col]``; presumably CHW.
        background_value: value treated as background.

    Returns:
        ``(row, col)`` of the first matching pixel, or ``None`` if every
        pixel in channel 0 equals ``background_value``.
    """
    # stop=-1 (not 0) so that column 0 and row 0 are included in the descending scan.
    for j in range(image.shape[2] - 1, -1, -1):
        for i in range(image.shape[1] - 1, -1, -1):
            if image[0, i, j] != background_value:
                return (i, j)
    return None


def calc_first_bottom_not_background_pixel(image, background_value: float):
    """Find the first non-background pixel when scanning from the bottom.

    Channel 0 is scanned row by row from the bottom. Within each row the scan
    goes right to left. Every row and column is visited, including index 0.

    Args:
        image: indexed as ``image[0, row, col]``; presumably CHW.
        background_value: value treated as background.

    Returns:
        ``(row, col)`` of the first matching pixel, or ``None`` if every
        pixel in channel 0 equals ``background_value``.
    """
    # stop=-1 (not 0) so that row 0 and column 0 are included in the descending scan.
    for i in range(image.shape[1] - 1, -1, -1):
        for j in range(image.shape[2] - 1, -1, -1):
            if image[0, i, j] != background_value:
                return (i, j)
    return None


def shifted_image(image: torch.Tensor, shifts: list):
    """Lazily yield translated copies of ``image``, one per ``(dx, dy)`` in ``shifts``.

    Each copy is produced by ``affine`` with angle 0, translation
    ``[dx, dy]``, scale 1 and no shear. Uncovered pixels are filled with the
    image's minimum value.

    Args:
        image: a 3-D tensor; the assert message expects CHW, at most 4 channels.
        shifts: iterable of ``(dx, dy)`` pairs.

    Yields:
        ``(shifted_tensor, (dx, dy))`` for each shift, in order.

    Raises:
        AssertionError: if ``image`` is not 3-D or has more than 4 channels.
            Because this is a generator, the check runs on the first ``next()``.
    """
    assert len(image.shape) == 3, 'image shape should be CHW'
    assert image.shape[0] <= 4, f'channels should be <= 4 - found {image.shape[0]}'
    # The fill depends only on `image`, so reduce it once instead of once per shift.
    fill = image.min().item()
    for (dx, dy) in shifts:
        yield affine(image, 0.0, [dx, dy], 1.0, [0.0, 0.0], fill=fill), (dx, dy)


def shift_image(image: torch.Tensor, shift: tuple):
    """Translate ``image`` by a single ``(dx, dy)`` offset.

    The transform is ``affine`` with angle 0, scale 1 and no shear. Uncovered
    pixels are filled with the image's minimum value.

    Args:
        image: tensor accepted by torchvision's ``affine``.
        shift: a 2-element ``(dx, dy)`` pair.

    Returns:
        The translated tensor returned by ``affine``.
    """
    dx, dy = shift
    background = image.min().item()
    translation = [dx, dy]
    return affine(image, 0.0, translation, 1.0, [0.0, 0.0], fill=background)
