import torch
from torchvision.transforms.v2 import ToPILImage

from applications.app_utils import create_gaussion_dist


def apply_windows(locations, intensity, windows, original_image, gauss=None):
    """Add intensity-scaled windows onto per-sample copies of an image.

    Args:
        locations: Tensor indexed as ``locations[j, 0, i]`` (row) and
            ``locations[j, 1, i]`` (column) for sample ``j`` and window ``i``.
            The processors in this module pass shape ``(B, 2, N)``, with values
            squashed into (0, 1) by a sigmoid. Values are multiplied by the
            image height/width and truncated to ints. Only the upper bound is
            clipped, so a window never runs past the bottom/right edge.
        intensity: Per-window scale factors. Three singleton dims are appended
            before broadcasting against ``windows``.
        windows: Tensor of shape ``(B, N, C, h, w)``. It is NOT modified.
        original_image: Base image, presumably ``(C, H, W)`` or
            ``(1, C, H, W)``. It is expanded to the batch size and cloned, so
            the original is not modified either.
        gauss: Currently unused; kept for backward compatibility.

    Returns:
        Tensor of shape ``(B, C, H, W)`` holding the image plus the windows.
    """
    h, w = windows.shape[-2], windows.shape[-1]
    img_h, img_w = original_image.shape[-2], original_image.shape[-1]
    num_of_windows = windows.shape[1]
    # ``windows.shape[1]`` above already requires >= 2 dims, so dim 0 is the batch.
    batch_size = windows.shape[0]
    new_images = original_image.expand(batch_size, -1, -1, -1).clone()

    base_h_idx = (locations[..., 0, :] * img_h).int().clip(max=img_h - h)
    base_w_idx = (locations[..., 1, :] * img_w).int().clip(max=img_w - w)

    # Out-of-place on purpose: callers commonly pass a reshaped slice of their
    # flat parameter tensor, and an in-place ``*=`` would overwrite it (and
    # break autograd on leaf tensors).
    windows = windows * intensity.view(intensity.shape + (1, 1, 1))

    # Convert offsets to Python ints once rather than indexing tensors
    # (and possibly syncing with the device) on every iteration.
    rows = base_h_idx.tolist()
    cols = base_w_idx.tolist()
    for i in range(num_of_windows):
        for j in range(locations.shape[0]):
            top, left = rows[j][i], cols[j][i]
            new_images[j, :, top : top + h, left : left + w] += windows[j, i]
    return new_images


class MiddleAdditionProcessor:
    """Add a single ``(c, h, w)`` patch to an image, anchored at its centre.

    The patch's top-left corner is placed at ``(H // 2, W // 2)`` of the base
    image. The output is always batched as ``(B, C, H, W)``, even when
    ``data`` is a single flat vector. If the patch extends past the image
    edge, the in-place addition fails with a shape mismatch.
    """

    def __init__(self, original_image, h, w, c):
        self.original_image = original_image
        self.h, self.w, self.c = h, w, c

    def __call__(self, data):
        patches = data.reshape(-1, self.c, self.h, self.w)
        batch = patches.shape[0]
        img_h, img_w = self.original_image.shape[-2], self.original_image.shape[-1]
        top, left = img_h // 2, img_w // 2

        # Clone so the shared base image is never written to.
        result = self.original_image.expand(batch, -1, -1, -1).clone()
        region = (
            Ellipsis,
            slice(None),
            slice(top, top + patches.shape[-2]),
            slice(left, left + patches.shape[-1]),
        )
        result[region] += patches
        return result

    @property
    def dims(self):
        """Number of scalars consumed per sample (one full patch)."""
        return self.c * self.h * self.w


class RandomAdditionProcessor:
    """Decode a flat vector into dense windows added at learned positions.

    With ``N = num_of_windows``, the last dimension of ``data`` is laid out as
    ``[locations (2N) | intensities (N) | windows (N * c * h * w)]``.
    Locations are reshaped to ``(-1, 2, N)``, scaled by 5 and passed through
    a sigmoid before being handed to ``apply_windows``. A 1-D ``data`` yields
    a single ``(C, H, W)`` image; batched ``data`` yields ``(B, C, H, W)``.
    """

    def __init__(self, original_image, h, w, c, num_of_windows=10, window_manipulator=None):
        self.original_image = original_image
        self.h, self.w, self.c = h, w, c
        self.num_of_windows = num_of_windows
        # Optional callable applied to the decoded windows before pasting.
        self.window_manipulator = window_manipulator

    def __call__(self, data):
        n = self.num_of_windows
        loc_end, intensity_end = 2 * n, 3 * n

        raw_locations = data[..., :loc_end].reshape(-1, 2, n)
        locations = (raw_locations * 5).sigmoid()
        intensity = data[..., loc_end:intensity_end]
        windows = data[..., intensity_end:].reshape(-1, n, self.c, self.h, self.w)

        if self.window_manipulator:
            windows = self.window_manipulator(windows)

        new_images = apply_windows(locations, intensity, windows, self.original_image)
        if len(data.shape) > 1:
            return new_images
        return new_images[0]

    @property
    def dims(self):
        """Scalars per sample: 2 location + 1 intensity + c*h*w pixels per window."""
        return self.num_of_windows * (self.h * self.w * self.c + 3)


class LoraWindowsProcessor:
    """Decode a flat vector into low-rank windows added to an image.

    Each of the ``N = num_of_windows`` windows of shape ``(c, h, w)`` is the
    product ``A @ B`` with ``A: (c, h, rank)`` and ``B: (c, rank, w)``. A
    window therefore costs ``(h + w) * rank * c`` scalars instead of
    ``h * w * c``. The last dimension of ``data`` is laid out as::

        [locations (2N) | intensities (N) | A (N*c*h*rank) | B (N*c*rank*w)]

    A 1-D ``data`` yields a single ``(C, H, W)`` image; batched ``data``
    yields ``(B, C, H, W)``.
    """

    def __init__(self, original_image, h, w, c, rank=3, num_of_windows=10, window_manipulator=None):
        self.original_image = original_image
        self.h = h
        self.w = w
        self.c = c
        self.rank = rank
        self.num_of_windows = num_of_windows
        # Optional callable applied to the decoded windows before pasting.
        self.window_manipulator = window_manipulator

    def _decode_windows(self, windows_data):
        """Rebuild ``(B, N, c, h, w)`` windows from the flattened low-rank factors."""
        n, c, h, w, r = self.num_of_windows, self.c, self.h, self.w, self.rank
        a_len = n * c * r * h
        lora_a = windows_data[..., :a_len].reshape(-1, n, c, h, r)
        lora_b = windows_data[..., a_len:].reshape(-1, n, c, r, w)
        return torch.matmul(lora_a, lora_b)

    def __call__(self, data):
        n = self.num_of_windows
        locations = (data[..., : 2 * n].reshape(-1, 2, n) * 5).sigmoid()
        intensity = data[..., 2 * n : 3 * n]

        windows = self._decode_windows(data[..., 3 * n :])
        if self.window_manipulator:
            windows = self.window_manipulator(windows)

        new_images = apply_windows(locations, intensity, windows, self.original_image)
        return new_images if len(data.shape) > 1 else new_images[0]

    @property
    def pixels_for_windows(self):
        """Total number of low-rank factor scalars across ALL windows."""
        return (self.h * self.rank + self.rank * self.w) * self.c * self.num_of_windows

    @property
    def dims(self):
        """Scalars per sample: factors for all windows + 3 per window (location, intensity)."""
        # ``pixels_for_windows`` already includes ``num_of_windows``; multiplying
        # by it again (as before) over-counted the factor block N-fold and did
        # not match what ``__call__`` decodes.
        return self.pixels_for_windows + 3 * self.num_of_windows


class ModelProcessorWrapper:
    """Adapt a PIL-based model processor so it accepts image tensors.

    ``model_processor`` is called as ``model_processor(pil_images,
    return_tensors="pt")``. Its ``.pixel_values`` are moved to the input's
    device and dtype. The whole conversion runs under ``torch.no_grad()``,
    so no gradient flows back to ``data``.
    """

    def __init__(self, model_processor):
        self.model_processor = model_processor

    def __call__(self, data):
        """Process a ``(B, C, H, W)`` batch, or a single ``(C, H, W)`` / ``(H, W)`` image.

        A single image returns the processor output without its batch dim.
        NOTE(review): a 3-D tensor is treated as ONE image, not as a batch of
        grayscale images.
        """
        # Previously ``len(data.shape) > 1`` was true for every image, so a
        # single (C, H, W) image was split into C grayscale images.
        is_many = data.dim() > 3
        if not is_many:
            data = data.unsqueeze(0)
        with torch.no_grad():
            # ToPILImage converts float tensors via ``mul(255).byte()``, so
            # pixels pushed outside [0, 1] by added windows would wrap around
            # instead of saturating; clamp them first.
            images = data.clamp(0, 1) if data.is_floating_point() else data
            to_pil = ToPILImage()
            pil_images = [to_pil(img) for img in images]
            processed_image = self.model_processor(
                pil_images, return_tensors="pt"
            ).pixel_values.to(device=data.device, dtype=data.dtype)
        return processed_image if is_many else processed_image[0]
