import struct
from abc import ABC

# from . import img_utils
import img_utils
import torch
import torch.nn.functional as F


class PyramidLatent:
    """Container for the result of encoding an image into a pyramid.

    ``initial_image`` starts as ``None`` and is filled in by the encoder;
    ``levels`` holds one list of tensors per pyramid level, and ``grayscale``
    is forwarded to the display helpers.
    """

    def __init__(self, n_levels: int, grayscale: bool = False):
        self.grayscale: bool = grayscale
        self.initial_image: torch.Tensor = None
        # Build one independent, initially empty list per level.
        self.levels: list[list[torch.Tensor]] = []
        for _ in range(n_levels):
            self.levels.append([])

    @torch.no_grad()
    def display(self, init_img: bool = True):
        """Show the stored image (optionally) and one row of tiles per level.

        Args:
            init_img: when true, ``initial_image`` is displayed first.
        """
        if init_img:
            img_utils.display_image(
                self.initial_image, title="Initial Image", grayscale=self.grayscale
            )
        for level_idx, level in enumerate(self.levels):
            # Concatenate all tensors of the level along dim 0, then hand
            # every slice to the row display with a leading singleton dim.
            merged = torch.cat(level, dim=0)
            tiles = [entry.unsqueeze(0) for entry in merged]
            img_utils.display_row(
                tiles,
                grayscale=self.grayscale,
                title=f"Level {level_idx}",
                add_resolution=True,
            )


class Pyramid(torch.nn.Module, ABC):
    def __init__(
        self,
        num_levels: int = 4,
        resolution: int = 256,
        grayscale: bool = True,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__()
        self.device = device
        self.grayscale = grayscale
        self.resolution = resolution
        self.num_levels = num_levels

    @torch.no_grad()
    def gaussian_kernel(self, size, sigma, n_channels: int = 1):
        x = torch.arange(
            -size // 2 + 1, size // 2 + 1, dtype=torch.float32, device=self.device
        )
        xx, yy = torch.meshgrid(x, x, indexing="ij")
        kernel = torch.exp(-(xx**2 + yy**2) / (2 * sigma**2))
        kernel /= kernel.sum()  # Normalize to ensure energy preservation
        # Create a diagonal kernel tensor for n_channels: (n_channels, n_channels, size, size)
        # Each channel only affects its corresponding output channel
        kernel = kernel.unsqueeze(0).unsqueeze(0)

        if n_channels > 1:
            extended_kernel = torch.zeros(
                n_channels, n_channels, size, size, device=self.device
            )
            for i in range(n_channels):
                extended_kernel[i, i] = kernel[0, 0].clone()
            return extended_kernel
        return kernel

    @torch.no_grad()
    def preprocess_image(self, image: torch.Tensor) -> torch.Tensor:
        # Rescale to 512x512 if needed -- image has to be a square with a power of 2 size as per https://www.ipol.im/pub/art/2014/79/
        if image.shape[1] != self.resolution or image.shape[2] != self.resolution:
            image = torch.nn.functional.interpolate(
                image,
                size=(self.resolution, self.resolution),
                mode="bilinear",
                align_corners=False,
            )
        if self.grayscale:
            image = image.mean(dim=1, keepdim=True)
        # Map to a device
        image = image.to(self.device)
        return image

    def encode(self, image: torch.Tensor, display: bool = False) -> PyramidLatent:
        raise NotImplementedError("Not implemented")

    def decode(self, latent: PyramidLatent) -> torch.Tensor:
        raise NotImplementedError("Not implemented")

    def display_kernels(self):
        raise NotImplementedError("Not implemented")


class GaussianPyramid(Pyramid):
    """Multi-resolution pyramid built by repeatedly downscaling the input.

    NOTE: although a Gaussian kernel is created and stored, ``downscale`` and
    ``upscale`` currently resample purely with ``F.interpolate``; the kernel
    is only used for display.
    """

    def __init__(
        self,
        *args,
        kernel_size: int = 3,
        kernel_sigma: float = 1.0,
        n_channels: int = 3,
        downscale_factor: int = 2,
        downscale_mode: str = "nearest",
        **kwargs,
    ):
        """Configure the pyramid and measure its per-level noise levels.

        Args:
            *args: forwarded to ``Pyramid.__init__``.
            kernel_size: side length of the stored Gaussian kernel.
            kernel_sigma: standard deviation of the stored Gaussian kernel.
            n_channels: channel count used when ``grayscale`` is false
                (grayscale pyramids always use a single channel).
            downscale_factor: resolution ratio between consecutive levels.
            downscale_mode: ``"nearest"`` selects nearest-neighbour
                downscaling; any other value selects bilinear.
            **kwargs: forwarded to ``Pyramid.__init__``.
        """
        super().__init__(*args, **kwargs)

        self.kernel_size = kernel_size
        self.kernel_sigma = kernel_sigma
        self.downscale_factor = downscale_factor
        self.downscale_mode = downscale_mode
        # Create kernel with appropriate number of channels
        self.n_channels = 1 if self.grayscale else n_channels
        # NOTE: this instance attribute shadows the inherited
        # ``gaussian_kernel`` method, so the builder cannot be called again
        # through this instance afterwards.
        self.gaussian_kernel = self.gaussian_kernel(
            self.kernel_size, self.kernel_sigma, self.n_channels
        )
        # NOTE: this runs ``self.encode`` and therefore dispatches to a
        # subclass override while the subclass constructor is still running.
        self._calculate_pyramid_sigmas()

    def display_kernels(self):
        """Show the stored Gaussian kernel."""
        img_utils.display_image(
            self.gaussian_kernel,
            title="Gaussian Kernel",
            grayscale=self.grayscale,
            add_resolution=True,
        )

    def _calculate_pyramid_sigmas(self):
        """Measure the standard deviation of every level for unit-noise input.

        Encodes one standard-normal image and stores, per level, the standard
        deviation of the first tensor of that level in ``self.sigmas``. (An
        earlier analytic variant derived the values from the sum of squared
        kernel weights raised to ``level / 2``.)
        """
        noise = torch.randn(
            1, self.n_channels, self.resolution, self.resolution, device=self.device
        )
        encoded = self.encode(noise)
        self.sigmas = [torch.std(level[0]).item() for level in encoded.levels]

    def downscale(self, image: torch.Tensor) -> torch.Tensor:
        """Shrink ``image`` by ``downscale_factor`` using the configured mode."""
        use_nearest = self.downscale_mode == "nearest"
        return F.interpolate(
            image,
            scale_factor=1 / self.downscale_factor,
            mode="nearest" if use_nearest else "bilinear",
            # ``align_corners`` is left unset for nearest-neighbour mode.
            align_corners=None if use_nearest else False,
        )

    def upscale(self, image: torch.Tensor) -> torch.Tensor:
        """Enlarge ``image`` by ``downscale_factor`` with bilinear interpolation."""
        return F.interpolate(
            image,
            scale_factor=self.downscale_factor,
            mode="bilinear",
            align_corners=False,
        )

    def encode(self, image: torch.Tensor) -> PyramidLatent:
        """Return a latent whose level ``i`` holds the input downscaled ``i`` times."""
        latent = PyramidLatent(self.num_levels, grayscale=self.grayscale)
        image_processed = self.preprocess_image(image)
        latent.initial_image = image_processed

        latent.levels[0].append(image_processed)
        for level_idx in range(self.num_levels - 1):
            image_processed = self.downscale(image_processed)
            latent.levels[level_idx + 1].append(image_processed)
        return latent

    def decode(self, latent: PyramidLatent) -> torch.Tensor:
        """Return the full-resolution image, which is stored as level 0."""
        return latent.levels[0][0]


class IdentityPyramid(GaussianPyramid):
    """Gaussian pyramid whose resampling steps are no-ops.

    Both ``downscale`` and ``upscale`` hand their input back untouched, so
    every level produced by the inherited ``encode`` has the same resolution.
    """

    def __init__(self, *args, **kwargs):
        # Nothing to add: all configuration is handled by GaussianPyramid.
        super().__init__(*args, **kwargs)

    def downscale(self, image: torch.Tensor) -> torch.Tensor:
        """Return ``image`` unchanged."""
        return image

    def upscale(self, image: torch.Tensor) -> torch.Tensor:
        """Return ``image`` unchanged."""
        return image


class LaplacianPyramid(GaussianPyramid):
    """Pyramid of detail residuals plus one low-frequency image.

    Each level except the last stores the difference between the current
    image and the upscaled version of its downscaled copy; the last level
    stores the remaining downscaled image.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def encode(self, image: torch.Tensor) -> PyramidLatent:
        """Split ``image`` into per-level residuals and a coarse remainder."""
        current = self.preprocess_image(image)
        latent = PyramidLatent(self.num_levels, grayscale=self.grayscale)
        latent.initial_image = current

        level_idx = 0
        while level_idx < self.num_levels - 1:
            coarse = self.downscale(current)
            # Detail lost by going down one level and coming back up.
            residual = current - self.upscale(coarse)
            latent.levels[level_idx].append(residual)
            current = coarse
            level_idx += 1

        latent.levels[-1].append(current)
        return latent

    def decode(self, latent: PyramidLatent) -> torch.Tensor:
        """Rebuild the image by upscaling and re-adding residuals, coarse to fine."""
        reconstruction = latent.levels[-1][0]  # low frequency residual
        for level_idx in reversed(range(self.num_levels - 1)):
            detail = latent.levels[level_idx][0]
            reconstruction = self.upscale(reconstruction) + detail
        return reconstruction


"""
As far as I understand, these kernels come from differentiating 2D Gaussians.
The K-th derivative gives K+1 steerable basis functions.
In practice, people report that higher orders work better.
"""


def get_steerable_kernels(n_orientations: int, resolution: int, device: torch.device):
    """Build the angular (orientation) masks of a steerable pyramid.

    Mask ``q`` is ``alpha * cos(psi_q) ** (Q - 1)``, where ``psi_q`` is the
    polar angle of each grid point shifted by ``pi * q / Q`` and wrapped so
    that its cosine is non-negative, and ``alpha`` is the normalization
    constant ``2**(Q-1) * (Q-1)! / sqrt(Q * (2(Q-1))!)``.

    Args:
        n_orientations: number of orientation bands ``Q`` (must be >= 1).
        resolution: side length of the square grid spanning [-1, 1]^2.
        device: device on which the masks are created.

    Returns:
        A list of ``n_orientations`` tensors, each of shape
        ``(1, 1, resolution, resolution)``.

    Raises:
        ValueError: if ``n_orientations`` is smaller than 1.
    """
    import math

    if n_orientations < 1:
        raise ValueError(
            f"n_orientations must be at least 1, got {n_orientations}"
        )

    Q = n_orientations
    order = Q - 1
    alpha = (
        (2.0**order) * math.factorial(order) / (Q * math.factorial(2 * order)) ** 0.5
    )

    # Create coordinate grid
    x = torch.linspace(-1, 1, resolution, device=device)
    y = torch.linspace(-1, 1, resolution, device=device)
    xx, yy = torch.meshgrid(x, y, indexing="ij")

    # Calculate angles for each position
    theta = torch.atan2(yy, xx)

    half_pi = torch.pi / 2
    masks = []
    for q in range(Q):
        # Shift by the band orientation, then wrap into [-pi/2, pi/2]
        psi = (torch.pi + theta - torch.pi * q / Q) % (2 * torch.pi) - half_pi
        psi = torch.where(psi > half_pi, psi - torch.pi, psi)
        # Create orientation mask with leading batch and channel dims
        mask = alpha * torch.cos(psi).pow(order)
        masks.append(mask.unsqueeze(0).unsqueeze(0))
    # The masks are returned as a list, one (1, 1, resolution, resolution)
    # tensor per orientation; they are not stacked.
    return masks


class SteerablePyramidFFT(LaplacianPyramid):
    """Steerable pyramid whose filters are applied in the Fourier domain.

    Every level except the last holds ``n_orientations`` oriented band
    responses (high-pass mask times orientation mask); the last level holds
    the remaining low-pass image. Spectra are kept centred (``fftshift``)
    so that downscaling is a central crop and upscaling is zero padding.
    """

    # False until __init__ has created the orientation and radial masks that
    # ``encode`` depends on.
    _fft_filters_ready: bool = False

    def __init__(
        self,
        *args,
        n_orientations: int = 4,
        ft_kernel_size: int = 256,
        radial_val: float = 0.5,
        twidth: float = 1.0,
        **kwargs,
    ):
        """Create the Fourier-domain masks and measure the per-level sigmas.

        Args:
            *args: forwarded to the parent constructor.
            n_orientations: number of orientation bands per level.
            ft_kernel_size: side length at which the masks are generated
                (they are resized to the spectrum size when applied).
            radial_val: radius, on the [-1, 1] grid, at which the high-pass
                mask reaches one.
            twidth: width of the transition region in octaves (log2 radius).
            **kwargs: forwarded to the parent constructor.
        """
        # The parent constructor calls ``self._calculate_pyramid_sigmas()``,
        # which runs ``self.encode``. ``encode`` needs the masks created
        # below, so that early call is skipped by the override in this class
        # and the measurement is performed once the masks exist.
        super().__init__(*args, **kwargs)
        self.n_orientations = n_orientations
        self.ft_kernel_size = ft_kernel_size
        self.radial_val = radial_val
        self.twidth = twidth
        # Even though its supposed to be masks -  we are applying them as kernels because
        #   the image is coming as a result of high pass filter already.
        self.steerable_kernels_FT = get_steerable_kernels(
            self.n_orientations, self.ft_kernel_size, self.device
        )

        self.hi_mask_FT, self.lo_mask_FT = self.create_pass_filters()

        self._fft_filters_ready = True
        self._calculate_pyramid_sigmas()

    def _calculate_pyramid_sigmas(self):
        """Measure the level sigmas, but only once the masks are available."""
        if not self._fft_filters_ready:
            return
        super()._calculate_pyramid_sigmas()

    def create_pass_filters(self):
        """Build the radial high-pass and low-pass masks.

        Returns:
            ``(hi_mask, lo_mask)``, each of shape
            ``(1, 1, ft_kernel_size, ft_kernel_size)``, constructed so that
            ``hi_mask**2 + lo_mask**2 == 1``.
        """
        # Inspired by https://medium.com/@itberrios6/steerable-pyramids-6bfd4d23c10d
        x = torch.linspace(-1, 1, self.ft_kernel_size, device=self.device)
        y = torch.linspace(-1, 1, self.ft_kernel_size, device=self.device)
        xx, yy = torch.meshgrid(x, y, indexing="ij")
        eps = 1e-10  # keeps log2 finite at the origin
        rad = torch.sqrt(xx**2 + yy**2)

        # Shift log radius (shifts by an octave if log2(radial_val) = 1)
        log_rad = torch.log2(rad + eps) - torch.log2(torch.tensor(self.radial_val))

        # Create high-pass mask: a cosine ramp from 0 to 1 over the octave
        # range [-twidth, 0] of the shifted log radius
        hi_mask = torch.clamp(log_rad, -self.twidth, 0)
        hi_mask = torch.abs(torch.cos(hi_mask * torch.pi / (2 * self.twidth)))
        lo_mask = torch.sqrt(1.0 - hi_mask**2)

        return hi_mask.unsqueeze(0).unsqueeze(0), lo_mask.unsqueeze(0).unsqueeze(0)

    def apply_ft_mask(
        self, image_fft: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Resize ``mask`` bilinearly to the spectrum size and multiply."""
        rescaled_mask = torch.nn.functional.interpolate(
            mask, size=image_fft.shape[-2:], mode="bilinear", align_corners=False
        )
        return image_fft * rescaled_mask

    def downscale_fft(self, image_fft: torch.Tensor) -> torch.Tensor:
        """Halve the resolution by keeping the central half of the spectrum."""
        h, w = image_fft.shape[-2:]
        return image_fft[..., h // 4 : 3 * h // 4, w // 4 : 3 * w // 4].clone()

    def upscale_fft(self, image_fft: torch.Tensor) -> torch.Tensor:
        """Double the resolution by zero-padding the spectrum on all sides."""
        h, w = image_fft.shape[-2:]
        target_shape = (image_fft.shape[0], image_fft.shape[1], 2 * h, 2 * w)
        upscaled = torch.zeros(target_shape, dtype=image_fft.dtype, device=self.device)
        upscaled[..., h // 2 : 3 * h // 2, w // 2 : 3 * w // 2] = image_fft
        return upscaled

    def fft(self, image: torch.Tensor) -> torch.Tensor:
        """2-D FFT with the zero frequency shifted to the centre."""
        return torch.fft.fftshift(torch.fft.fft2(image))

    def ifft(self, image_fft: torch.Tensor) -> torch.Tensor:
        """Inverse of ``fft``; only the real part is returned."""
        return torch.fft.ifft2(torch.fft.ifftshift(image_fft)).real

    def encode(self, image: torch.Tensor) -> PyramidLatent:
        """Decompose ``image`` into oriented bands per level plus a low-pass image."""
        latent = PyramidLatent(self.num_levels, grayscale=self.grayscale)
        image_processed = self.preprocess_image(image)
        latent.initial_image = image_processed

        # Change to full FFT
        image_fft = self.fft(image_processed)
        for level_idx in range(self.num_levels - 1):
            # Apply orientation filters and high-pass filter
            for orientation_idx in range(self.n_orientations):
                band_response = self.apply_ft_mask(image_fft, self.hi_mask_FT)
                band_response = self.apply_ft_mask(
                    band_response, self.steerable_kernels_FT[orientation_idx]
                )
                latent.levels[level_idx].append(self.ifft(band_response))

            # The low-pass remainder is cropped and fed to the next level.
            lo_response = self.apply_ft_mask(image_fft, self.lo_mask_FT)
            image_fft = self.downscale_fft(lo_response)

        image_processed = self.ifft(image_fft)
        latent.levels[-1].append(image_processed)
        return latent

    def decode(self, latent: PyramidLatent) -> torch.Tensor:
        """Reconstruct an image by re-applying the masks, coarse to fine."""
        # Change to full FFT
        image_fft = self.fft(latent.levels[-1][0])
        for level_idx in range(self.num_levels - 2, -1, -1):
            lo_response = self.apply_ft_mask(
                self.upscale_fft(image_fft), self.lo_mask_FT
            )
            # Each stored band is filtered again with its own orientation
            # mask before the bands are summed and high-passed.
            band_response = [
                self.apply_ft_mask(self.fft(lbr), self.steerable_kernels_FT[i])
                for i, lbr in enumerate(latent.levels[level_idx])
            ]
            hi_response = self.apply_ft_mask(sum(band_response), self.hi_mask_FT)
            image_fft = lo_response + hi_response
        return self.ifft(image_fft)

    def display_kernels(self):
        """Show the radial masks and the orientation masks with their squared sums."""
        img_utils.display_row(
            [
                self.hi_mask_FT,
                self.lo_mask_FT,
                self.hi_mask_FT**2 + self.lo_mask_FT**2,
            ],
            title="High and Low pass filters in Fourier space (Last one is the sum of their squares)",
            grayscale=self.grayscale,
            add_resolution=True,
        )
        img_utils.display_row(
            self.steerable_kernels_FT
            + [
                sum(
                    [k**2 for k in self.steerable_kernels_FT]
                )  # sum(self.steerable_kernels_FT)
            ],
            title="Steerable Kernels in Fourier space (Last one is the sum of all)",
            grayscale=self.grayscale,
            add_resolution=True,
        )


"""
Here the high-pass filter is approximated by the Laplacian pyramid step,
i.e. we don't have the kernel explicitly.
"""


class SteerablePyramidSpatial(SteerablePyramidFFT):
    """Steerable pyramid that applies the orientation filters by convolution.

    The Fourier-domain orientation masks are converted to the spatial domain
    and cropped to small kernels; the band-pass signal they act on is the
    Laplacian-style residual ``image - upscale(downscale(image))``.
    """

    # False until __init__ has created ``steerable_kernels``, which
    # ``encode`` depends on.
    _spatial_kernels_ready: bool = False

    def __init__(
        self,
        *args,
        ft_kernel_size: int = 255,  # we need it to be odd to have a center pixel;
        spatial_bandpass_kernel_size: int = 15,
        **kwargs,
    ):
        """Derive the truncated spatial kernels and measure the level sigmas.

        Args:
            *args: forwarded to the parent constructor.
            ft_kernel_size: side length of the Fourier-domain masks.
            spatial_bandpass_kernel_size: nominal side length of the cropped
                spatial kernels (the crop spans ``2 * (size // 2) + 1``
                pixels around the centre).
            **kwargs: forwarded to the parent constructor.
        """
        # Parent constructors call ``self._calculate_pyramid_sigmas()``,
        # which runs ``self.encode``. ``encode`` needs the kernels built
        # below, so those early calls are skipped by the override in this
        # class and the measurement is performed at the end.
        super().__init__(*args, ft_kernel_size=ft_kernel_size, **kwargs)
        self.spatial_bandpass_kernel_size = spatial_bandpass_kernel_size
        self.steerable_kernels = []
        offset = self.spatial_bandpass_kernel_size // 2
        for mask in self.steerable_kernels_FT:
            # Convert from Fourier to spatial domain
            mask_spatial_domain = self.ifft(mask.squeeze(0).squeeze(0))
            # Shift result to have center at center (for visualization and cropping)
            mask_spatial_domain = torch.fft.fftshift(mask_spatial_domain)
            # Crop a square window around the centre pixel
            center = mask_spatial_domain.shape[0] // 2
            mask_truncated = mask_spatial_domain[
                center - offset : center + offset + 1,
                center - offset : center + offset + 1,
            ]
            self.steerable_kernels.append(mask_truncated.unsqueeze(0).unsqueeze(0))

        self._spatial_kernels_ready = True
        self._calculate_pyramid_sigmas()

    def _calculate_pyramid_sigmas(self):
        """Measure the level sigmas, but only once the kernels are available."""
        if not self._spatial_kernels_ready:
            return
        super()._calculate_pyramid_sigmas()

    def _filter_per_channel(
        self, image: torch.Tensor, kernel: torch.Tensor
    ) -> torch.Tensor:
        """Convolve every channel of ``image`` with the single-channel ``kernel``."""
        batch, channels, height, width = image.shape
        # Fold channels into the batch so the (1, 1, k, k) kernel applies to
        # each channel independently; a no-op for single-channel input.
        folded = image.reshape(batch * channels, 1, height, width)
        filtered = torch.conv2d(
            folded, kernel, padding=self.spatial_bandpass_kernel_size // 2
        )
        return filtered.reshape(batch, channels, *filtered.shape[-2:])

    def encode(self, image: torch.Tensor) -> PyramidLatent:
        """Decompose ``image`` into oriented band responses plus a coarse image."""
        latent = PyramidLatent(self.num_levels, grayscale=self.grayscale)
        image_processed = self.preprocess_image(image)
        latent.initial_image = image_processed

        for level_idx in range(self.num_levels - 1):
            downscaled = self.downscale(image_processed)
            # Residual lost by the down/up round trip serves as band-pass.
            band_pass = image_processed - self.upscale(downscaled)
            for orientation_idx in range(self.n_orientations):
                latent.levels[level_idx].append(
                    self._filter_per_channel(
                        band_pass, self.steerable_kernels[orientation_idx]
                    )
                )
            image_processed = downscaled

        latent.levels[-1].append(image_processed)
        return latent

    def decode(self, latent: PyramidLatent) -> torch.Tensor:
        """Rebuild an image by re-filtering and summing the bands, coarse to fine."""
        image = latent.levels[-1][0]  # low frequency residual
        for level_idx in range(self.num_levels - 2, -1, -1):
            responses = [
                self._filter_per_channel(
                    latent.levels[level_idx][i], self.steerable_kernels[i]
                )
                for i in range(self.n_orientations)
            ]
            image = self.upscale(image) + sum(responses)
        return image

    def display_kernels(self):
        """Show the Gaussian kernel and the orientation filters in both domains."""
        img_utils.display_image(
            self.gaussian_kernel,
            title="Gaussian Kernel",
            grayscale=self.grayscale,
            add_resolution=True,
        )
        img_utils.display_row(
            self.steerable_kernels_FT
            + [
                sum(self.steerable_kernels_FT)
            ],  # [sum([k**2 for k in self.steerable_kernels_FT])]
            title="Steerable Kernels in Fourier space (Last one is the sum of all)",
            grayscale=self.grayscale,
            add_resolution=True,
        )
        img_utils.display_row(
            self.steerable_kernels + [sum(self.steerable_kernels)],
            title="Steerable Kernels in Spatial Domain (Last one is sum of all)",
            grayscale=self.grayscale,
            add_resolution=True,
        )
