import numpy as np
from PIL import Image

from src.datamodules.transforms import UnNormalize
from src.image_patch.img_to_patch import img_to_patch
import cv2

def _to_uint8(tensor):
    """Scale a float tensor expected in [0, 1] to a uint8 NumPy array.

    Values are clipped to [0, 255] before the cast; a bare ``astype(np.uint8)``
    on out-of-range floats wraps (or is platform-dependent), so e.g. a pixel at
    1.02 would come out nearly black instead of white.
    """
    return np.clip(255 * tensor.cpu().detach().numpy(), 0, 255).astype(np.uint8)


def visualize_patch_importance(x, importance, threshold=None, color=(255, 0, 0), patch_size=16, unnormalize=True, n_images=3):
    """Render per-patch importance as a colour overlay on a batch of images.

    Each image is split into ``patch_size`` x ``patch_size`` patches via
    ``img_to_patch``. Every patch is blended with a solid ``color`` patch using
    ``cv2.addWeighted``: the image patch is weighted by the patch importance
    ``alpha`` and the colour by ``1 - alpha``. Importance 1 therefore shows the
    patch unchanged, and importance 0 replaces it entirely with ``color``.

    Args:
        x: Image batch tensor indexed as (B, C, H, W), as the
            ``permute(0, 2, 3, 1)`` below requires. Presumably C == 3 (RGB);
            the un-normalisation statistics and the 3-channel colour patch
            assume it.
        importance: 3-D torch tensor or NumPy array indexed as
            ``importance[b][patch][0]``. Only the first entry of the last
            axis is used.
        threshold: If given, importance is binarised. Patches with
            importance > threshold are shown unchanged; all others are fully
            coloured.
        color: RGB triple used for the overlay.
        patch_size: Side length in pixels of a square patch.
        unnormalize: If True, undo ImageNet mean/std normalisation first.
        n_images: Maximum number of images from the batch to render.

    Returns:
        ``(images, original)``: two lists of ``PIL.Image``. The first holds
        the overlaid images, the second the corresponding unmodified images.

    Raises:
        ValueError: If ``importance`` is not 3-dimensional.
    """
    # Validate before touching x; only aggregation to one score per patch is supported.
    if importance.ndim != 3:
        raise ValueError("Invalid importance shape")

    if unnormalize:
        # UnNormalize's implementation is not visible here. Clone so that an
        # in-place version cannot modify the caller's tensor.
        denormalize = UnNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        x = denormalize(x.detach().clone())

    original_images = _to_uint8(x.permute(0, 2, 3, 1))  # channels-last for PIL

    # img_to_patch(..., flatten_channels=False) is treated as (B, N, C, p, p).
    # Move channels last so OpenCV/PIL can consume each patch.
    patches_batch = img_to_patch(x, patch_size, flatten_channels=False)
    patches = _to_uint8(patches_batch.permute(0, 1, 3, 4, 2))
    num_patches = patches.shape[1]

    # Patches are assumed to be ordered row-major over the patch grid (the
    # divmod below relies on this). The column count comes from the image
    # width so non-square images are reassembled correctly. For square images
    # this equals sqrt(num_patches).
    grid_cols = max(1, x.shape[-1] // patch_size)
    grid_rows = (num_patches + grid_cols - 1) // grid_cols

    if hasattr(importance, "detach"):
        # Move to host memory once instead of syncing the device on every
        # per-patch lookup below.
        importance = importance.detach().cpu()

    color_patch = np.full((patch_size, patch_size, 3), color, dtype=np.uint8)

    images = []
    original = []

    for image_index in range(min(n_images, x.shape[0])):
        canvas = np.zeros((grid_rows * patch_size, grid_cols * patch_size, 3), dtype=np.uint8)
        for i in range(num_patches):
            row, col = divmod(i, grid_cols)
            score = importance[image_index][i][0]
            if threshold is not None:
                alpha = float(score > threshold)
            else:
                # 0 == patch fully replaced by colour, 1 == patch unchanged.
                # Clamp so addWeighted stays a convex blend.
                alpha = min(max(float(score), 0.0), 1.0)
            weighted_patch = cv2.addWeighted(patches[image_index][i], alpha, color_patch, 1 - alpha, 0)
            canvas[row * patch_size:(row + 1) * patch_size, col * patch_size:(col + 1) * patch_size] = weighted_patch

        images.append(Image.fromarray(canvas))
        original.append(Image.fromarray(original_images[image_index]))

    return images, original


