import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision.transforms.v2 import Resize, ToTensor


class NormalizeWindowsSoftmax:
    """Turn a tensor into a scaled, clipped softmax over all of its elements.

    A 5-D input is treated as a batch along dim 0: one softmax is taken over
    all remaining elements of each leading index. Any other rank is treated
    as a single sample, with one softmax over every element. The resulting
    probabilities are multiplied by ``num_of_points_to_keep`` and clipped to
    ``[-1, 1]``. The output has the same shape as the input.
    """

    def __init__(self, num_of_points_to_keep=100):
        # Scale applied to the softmax probabilities. The probabilities sum
        # to 1 per sample, so before clipping each sample's total mass equals
        # this value.
        self.num_of_points_to_keep = num_of_points_to_keep

    def __call__(self, images):
        """Apply the scaled, clipped softmax to ``images`` (any rank; 5-D = batch)."""
        is_batch = images.dim() == 5
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # permuted tensors, are accepted. view() raises on those.
        flatten_imgs = images.reshape(images.shape[0] if is_batch else 1, -1)
        pixels_values = flatten_imgs.softmax(dim=1) * self.num_of_points_to_keep
        # Softmax output is non-negative. With a non-negative scale, only the
        # upper bound of the clip has any effect.
        return pixels_values.reshape(images.shape).clip(min=-1, max=1)


def extract_point_from_image(original_image, created_image, *, steps=10):
    """Display ``original_image - created_image`` at increasing noise cut-offs.

    The difference is shown once per threshold ``i / steps`` for ``i`` in
    ``range(steps)``. With the default this is 0.0, 0.1, ..., 0.9. Each plot
    zeroes values strictly between ``-threshold`` and ``threshold``.

    Args:
        original_image: reference image tensor. Presumably (C, H, W) as
            produced by ``ToTensor``; TODO confirm for other callers.
        created_image: image tensor to compare against. It must be
            broadcast-compatible with ``original_image``.
        steps: number of thresholds to display. A non-positive value shows
            nothing.
    """
    diff = original_image - created_image
    for step in range(steps):
        print_with_noise_cancel(diff, step / steps)


def cancel_noise(image, min_noise_value):
    """Return a copy of ``image`` with entries where ``abs(x) < min_noise_value`` set to 0.

    The input tensor is not modified.
    """
    # abs(x) < t is the same as -t < x < t. NaN compares False either way,
    # so NaN entries are kept unchanged.
    return image.masked_fill(image.abs() < min_noise_value, 0)


def print_with_noise_cancel(image, min_noise_value):
    """Show ``image`` with small-magnitude values zeroed out.

    Args:
        image: channels-first (C, H, W) tensor, or a 2-D (H, W) tensor. The
            tensor is not modified.
        min_noise_value: entries with ``abs(value) < min_noise_value`` are set
            to 0 before display. The value is also used in the plot title.
    """
    cleaned = cancel_noise(image, min_noise_value)
    # detach/cpu so grad-tracking or GPU tensors can be converted to numpy.
    cleaned = cleaned.detach().cpu()
    if cleaned.dim() == 3:
        # Channels-first -> channels-last, the layout matplotlib expects.
        cleaned = cleaned.permute(1, 2, 0)
    # NOTE: for 3/4-channel float data imshow clips to [0, 1], so negative
    # values (e.g. from an image difference) are rendered as 0.
    plt.imshow(cleaned.numpy())
    plt.title(f"min noise value: {min_noise_value}")
    plt.show()


if __name__ == "__main__":
    original_image_path = (
        r"D:\projects\black-box-optimization\applications\images\random\bison_347.png"
    )
    created_path = r"D:\projects\black-box-optimization\model_7.png"
    resize_image = Resize((150, 150))
    to_tensor = ToTensor()
    # The context managers close the underlying files. convert("RGB") gives
    # both images the same 3-channel colour layout. Without it, a palette
    # ("P") PNG is converted to palette indices instead of colours, and
    # RGBA-vs-RGB inputs cannot be subtracted.
    with Image.open(original_image_path) as original_file:
        original = resize_image(to_tensor(original_file.convert("RGB")))
    with Image.open(created_path) as created_file:
        created = resize_image(to_tensor(created_file.convert("RGB")))
    extract_point_from_image(original, created)
