import numpy as np
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer
from pathlib import Path
from PIL import Image, ImageOps
from typing import Sequence, Tuple, List
from scipy.ndimage import gaussian_filter
from sklearn.metrics.pairwise import cosine_distances


# File-name extensions (without the leading dot) that get_images_paths
# searches for. Each entry is interpolated verbatim into a glob pattern
# there, which is why lower- and upper-case spellings are listed separately.
ALLOWED_IMAGE_EXTENSIONS = [
    'png', 'PNG',
    'jpg', 'JPG', 'jpeg', 'JPEG'
]


def get_images_paths(base_path: Path) -> List[Path]:
    """Return the sorted paths directly inside ``base_path`` that look like images.

    A path is included when it matches ``*.<ext>`` for any entry of
    ``ALLOWED_IMAGE_EXTENSIONS``. The search is not recursive. Matches are
    gathered in a set so a path that matches more than one pattern is
    returned only once.
    """
    unique_paths = {
        match
        for ext in ALLOWED_IMAGE_EXTENSIONS
        for match in base_path.glob(f'*.{ext}')
    }

    return sorted(unique_paths)


def load_images(image_paths: Sequence[Path]) -> List[Image.Image]:
    """Open every path in ``image_paths`` and return the decoded images.

    ``Image.open`` is lazy: it only identifies the file and keeps the file
    handle open until the pixel data is read. Opening many images that way
    holds one open file per image, so the pixel data is read immediately
    here, which lets Pillow release the underlying file for ordinary
    single-frame images.

    Args:
        image_paths: Paths of the image files to open, in the desired order.

    Returns:
        One image per input path, in the same order.

    Raises:
        Whatever ``Image.open`` / ``Image.load`` raise for a missing,
        unreadable or corrupt file.
    """
    images = []
    for path in image_paths:
        image = Image.open(path)
        try:
            # Read the pixel data now instead of on first use.
            image.load()
        except Exception:
            # Do not leave the file handle dangling when decoding fails.
            image.close()
            raise
        images.append(image)

    return images


def pad_images_to_max_size(image_paths: Sequence[Path], pad_color: Tuple[int, int, int] = (0, 0, 0)) -> List[Image.Image]:
    """Load images and pad each one to the largest width and height among them.

    Every image is centred on a canvas of ``max_width x max_height``. When
    the amount of padding along an axis is odd, the extra pixel goes to the
    right / bottom edge.

    Args:
        image_paths: Paths of the images to load (via ``load_images``).
        pad_color: Fill colour used for the added border.
            NOTE(review): this is an RGB triple; images in other modes
            (e.g. greyscale or palette) may not accept it - confirm.

    Returns:
        The padded images, in the same order as ``image_paths``. An empty
        list when ``image_paths`` is empty.
    """
    images = load_images(image_paths)

    # max() raises ValueError on an empty sequence; with nothing to pad
    # there is simply nothing to return.
    if not images:
        return []

    max_width = max(image.width for image in images)
    max_height = max(image.height for image in images)

    padded_images = []
    for image in images:
        delta_w = max_width - image.width
        delta_h = max_height - image.height
        left = delta_w // 2
        top = delta_h // 2
        # Border order expected by ImageOps.expand: (left, top, right, bottom).
        border = (left, top, delta_w - left, delta_h - top)
        padded_images.append(ImageOps.expand(image, border, pad_color))

    return padded_images


def create_heatmap(scores: np.ndarray, patch_size: Tuple[int, int], image_size: Tuple[int, int]) -> np.ndarray:
    """Expand per-patch scores into a pixel-resolution heatmap.

    The image is treated as a grid of ``image_size[0] // patch_size[0]``
    columns by ``image_size[1] // patch_size[1]`` rows. ``scores`` are
    written into that grid in row-major order (left to right, then top to
    bottom); cells without a score stay at 0. Each cell is then blown up to
    a block of one patch's size.

    Args:
        scores: One score per patch, in row-major patch order.
        patch_size: ``(width, height)`` of a patch in pixels.
        image_size: ``(width, height)`` of the image in pixels.

    Returns:
        A 2-D float array of shape
        ``(rows * patch_height, cols * patch_width)``.

    Raises:
        IndexError: If there are more scores than grid cells.
        ZeroDivisionError: If scores are given but the image is narrower
            than one patch.
    """
    patch_width, patch_height = patch_size
    patches_per_row = image_size[0] // patch_width
    patches_per_col = image_size[1] // patch_height

    heatmap = np.zeros((patches_per_col, patches_per_row))

    for index, score in enumerate(scores):
        row, col = divmod(index, patches_per_row)
        heatmap[row, col] = score

    # NumPy shapes are (rows, cols) == (height, width), whereas patch_size is
    # (width, height); swap the order so non-square patches expand correctly.
    return np.kron(heatmap, np.ones((patch_height, patch_width)))


def overlay_heatmap(image: Image.Image, heatmap: np.ndarray, patch_size: int, alpha: float = 0.5, colormap: str = 'jet') -> Image.Image:
    """Blend a colour-mapped, smoothed version of ``heatmap`` over ``image``.

    The heatmap is min-max normalised to [0, 1], blurred with a Gaussian
    filter (``sigma = patch_size / 2``), coloured with the named matplotlib
    colormap, resized to the image size and alpha-blended with the image.

    Args:
        image: Base image. Images that are not in ``'RGB'`` mode are
            converted to RGB first (any alpha channel is dropped), because
            ``Image.blend`` requires both inputs to share the same mode.
        heatmap: 2-D array of scores.
        patch_size: Patch edge length in pixels; controls the blur radius.
        alpha: Blend factor passed to ``Image.blend`` (0 keeps the image,
            1 shows only the heatmap).
        colormap: Name of the matplotlib colormap to use.

    Returns:
        A new RGB image with the heatmap overlaid.
    """
    heatmap_min = np.min(heatmap)
    value_range = np.max(heatmap) - heatmap_min
    if value_range == 0:
        # A flat heatmap has nothing to normalise; dividing would give NaNs.
        heatmap_normalized = np.zeros_like(heatmap, dtype=float)
    else:
        heatmap_normalized = (heatmap - heatmap_min) / value_range

    heatmap_smoothed = gaussian_filter(heatmap_normalized, sigma=patch_size / 2)

    cmap = plt.get_cmap(colormap)
    heatmap_colored = cmap(heatmap_smoothed)

    # The colormap returns RGBA floats in [0, 1]; keep RGB as 8-bit values.
    heatmap_image = Image.fromarray((heatmap_colored[:, :, :3] * 255).astype(np.uint8))

    heatmap_image = heatmap_image.resize(image.size, resample=Image.Resampling.BILINEAR)

    # Image.blend raises when the modes differ, and the heatmap is always RGB.
    base_image = image if image.mode == 'RGB' else image.convert('RGB')

    return Image.blend(base_image, heatmap_image, alpha)


def split_image_into_patches(image: Image.Image, patch_size: int) -> List[Image.Image]:
    """Cut ``image`` into square, non-overlapping ``patch_size`` patches.

    Only whole patches are produced: pixels beyond the last full patch on
    the right and bottom edges are left out. Patches are returned in
    row-major order (left to right, then top to bottom).
    """
    width, height = image.size
    columns = width // patch_size
    rows = height // patch_size

    return [
        image.crop((
            col * patch_size,
            row * patch_size,
            (col + 1) * patch_size,
            (row + 1) * patch_size,
        ))
        for row in range(rows)
        for col in range(columns)
    ]


def scale_array(arr):
    """Min-max scale ``arr`` to the range [0, 1].

    Args:
        arr: A non-empty NumPy array.

    Returns:
        A float array of the same shape where the minimum maps to 0 and the
        maximum to 1. When every element is equal (including a
        single-element array) there is no range to scale by, so an array of
        zeros is returned instead of dividing by zero.
    """
    arr_min = arr.min()
    value_range = arr.max() - arr_min

    if value_range == 0:
        # Constant input: avoid 0 / 0, which would fill the result with NaN.
        return np.zeros_like(arr, dtype=float)

    return (arr - arr_min) / value_range


def create_image_with_bias_heatmap(image: Image.Image, embedding_model: SentenceTransformer, bias_embedding: np.ndarray, patch_size: int) -> Image.Image:
    """Overlay a heatmap showing how close each image patch is to ``bias_embedding``.

    The image is split into square patches, each patch is embedded with
    ``embedding_model``, and the cosine distance of every patch embedding to
    ``bias_embedding`` is min-max scaled and inverted, so the closest patch
    scores 1 and the farthest scores 0. Those scores are rendered as a
    heatmap on top of the image.

    Args:
        image: Image to analyse.
        embedding_model: Model whose ``encode`` accepts the image patches.
        bias_embedding: Reference embedding; flattened to a single row
            before the distance computation.
        patch_size: Edge length of the square patches, in pixels.

    Returns:
        The image with the heatmap blended over it.

    Raises:
        ValueError: If the image is smaller than a single patch, so there
            is nothing to score.
    """
    proto_patches = split_image_into_patches(image, patch_size=patch_size)
    if not proto_patches:
        raise ValueError(
            f'Image of size {image.size} is smaller than one '
            f'{patch_size}x{patch_size} patch; there are no patches to score.'
        )

    # Encode all patches in a single call so the model can batch them instead
    # of paying the per-call overhead once per patch.
    patch_embs = embedding_model.encode(proto_patches)  # type: ignore

    patch_distances = scale_array(cosine_distances(bias_embedding.reshape(1, -1), patch_embs)[0])  # type: ignore
    heatmap = create_heatmap(1 - patch_distances, patch_size=(patch_size, patch_size), image_size=image.size)

    return overlay_heatmap(image, heatmap, patch_size)
