from __future__ import annotations
from typing import Optional, Iterator
from abc import ABC, abstractmethod

import torch
import torch.nn as nn

from . import deep_image_prior, image_domain
import numpy as np


@torch.no_grad()
def _reset_weight(m: nn.Module) -> None:
    """Re-initialize ``m`` via its own ``reset_parameters`` method, if any.

    Modules that have no ``reset_parameters`` attribute, or whose attribute
    is not callable, are left untouched. The call runs with gradient
    tracking disabled. Intended to be used with ``nn.Module.apply``.
    """
    candidate = getattr(m, "reset_parameters", None)
    if not callable(candidate):
        # Nothing to reset on this module (e.g. activations, containers).
        return
    m.reset_parameters()


class Generator(nn.Module, ABC):
    """Abstract base class for image generators.

    Note
    ----
    Concrete generators are expected to provide the following methods:

    reset_states()
        Puts the generator back into its initial state. It is invoked before
        images are generated for each stimulus. This is the only method
        declared abstract here.
    forward()
        Produces the images. It is invoked once per iteration of the
        optimization process, takes no arguments, and the images it returns
        must lie in the range [0, 1].
    """

    @abstractmethod
    def reset_states(self) -> None:
        """Put the generator back into its initial state."""


class noGenerator(Generator):
    """Generator that optimizes the pixel values directly.

    Every sample of the batch is stored as its own ``(1, 3, H, W)``
    ``nn.Parameter``. ``forward`` concatenates them along the batch axis,
    clamps the result to ``[0, 1]`` and passes it through
    ``self.domain.send``.

    Note
    ----
    The per-sample parameters are kept in a plain Python list, so they are
    not registered with ``nn.Module``: ``state_dict()`` and ``.to(...)`` do
    not see them. Select ``device`` / ``dtype`` at construction time. The
    overridden ``parameters`` method is what exposes them to an optimizer.

    Args:
        image_shape (tuple[int, int]): Height and width of the images.
        batch_size (int): Number of images to generate.
        device (torch.device | None): Device the pixel tensors are created on.
        dtype (torch.dtype | None): Data type of the pixel tensors.
    """

    def __init__(
        self,
        image_shape: tuple[int, int],
        batch_size: int,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self._image_shape = image_shape
        self._batch_size = batch_size
        # One parameter per sample; contents are filled in by reset_states().
        self._latent_list = [
            nn.Parameter(torch.empty([1, 3, *image_shape], **factory_kwargs))
            for _ in range(batch_size)
        ]
        self.domain = image_domain.Zero2OneImageDomain()
        self.reset_states()

    def reset_states(self, gen_seeds: Optional[list[int]] = None) -> None:
        """Re-draw every image from a normal distribution (mean 0.5, std 0.1).

        Args:
            gen_seeds: Must be ``None``; per-sample seeding is not supported.

        Raises:
            NotImplementedError: If ``gen_seeds`` is not ``None``.
        """
        if gen_seeds is not None:
            raise NotImplementedError(
                "noGenerator does not support setting seeds yet."
            )
        for latent in self._latent_list:
            nn.init.normal_(latent, 0.5, 0.1)

    def forward(self) -> torch.Tensor:
        """Return the current images, clamped to ``[0, 1]`` and sent through the domain."""
        image = torch.cat(self._latent_list, dim=0)
        image = torch.clamp(image, 0.0, 1.0)
        return self.domain.send(image)

    def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]:
        """Return an iterator over the per-sample pixel parameters.

        ``recurse`` is accepted for signature compatibility with
        ``nn.Module.parameters`` and is ignored.
        """
        return iter(self._latent_list)
    

class noGeneratorInitImage(noGenerator):
    """
    No generator with arbitrary initial image.

    Behaves like ``noGenerator`` but lets the caller overwrite the current
    pixel values through ``set_images``.
    """
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def set_images(self, images: torch.Tensor) -> None:
        """
        Set the latent to the given images.

        The values are copied, so later in-place updates of the latents
        (e.g. optimizer steps) do not write into ``images`` and later changes
        to ``images`` do not leak into the latents. As before, the latents
        take on the device and dtype of ``images``.

        Args:
            images (torch.Tensor): shape (batch_size, 3, height, width)
                with values in [0, 1].

        Raises:
            ValueError: If ``images`` is not 4-dimensional, or if its batch
                size, height or width does not match this generator.
        """
        # Unpacking raises ValueError by itself when ``images`` is not 4-D.
        b, _, h, w = images.shape
        if b != self._batch_size:
            raise ValueError(
                f"expected {self._batch_size} images, got {b}"
            )
        if h != self._image_shape[0] or w != self._image_shape[1]:
            raise ValueError(
                f"expected images of size {tuple(self._image_shape)}, "
                f"got {(h, w)}"
            )
        for latent, image in zip(self._latent_list, images):
            # clone() breaks the storage sharing that a bare view would create.
            latent.data = image.unsqueeze(0).detach().clone()


class DeepImagePriorGenerator(Generator):
    """
    Generator that uses the Deep Image Prior.
    The image is generated by a UNet, and it is updated by
    optimizing the parameters of the UNet.
    The latent image is initialized with random values and is not exposed
    through ``parameters()``, so an optimizer built from it leaves the
    latent unchanged.

    Args:
        image_shape (tuple[int, int]): Shape of the generated image.
        batch_size (int): Number of images to generate.
        device (torch.device | None): Device to use for the generator.
        dtype (torch.dtype | None): Data type to use for the generator.
        use_sigmoid (bool): Whether to use sigmoid activation in the UNet.
    """
    def __init__(
        self,
        image_shape: tuple[int, int],
        batch_size: int,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        use_sigmoid: bool = True,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self._image_shape = image_shape
        self._batch_size = batch_size

        # Note: we need unets for each sample
        self.unets = nn.ModuleList([
            deep_image_prior.get_net(
                input_depth=3,
                NET_TYPE="UNet",
                pad="reflection",
                upsample_mode="bilinear",
                need_sigmoid=use_sigmoid,
            ) for _ in range(batch_size)
        ])
        # Move AND cast the networks so that they match the latent images,
        # which are created with the same device/dtype below.
        for unet in self.unets:
            unet.to(device=device, dtype=dtype)

        # Kept in a plain list on purpose: the latents are inputs to the
        # UNets, not something parameters() hands to the optimizer.
        self.latent_images = [
            nn.Parameter(
                torch.empty(1, 3, *image_shape, **factory_kwargs)
            ) for _ in range(batch_size)
        ]

        self.domain = image_domain.Zero2OneImageDomain()
        self.reset_states()

    def reset_states(self, gen_seeds: Optional[list[int]] = None) -> None:
        """Re-draw every latent image and re-initialize every UNet.

        Args:
            gen_seeds: Optional per-sample seeds. When given, ``gen_seeds[i]``
                seeds the global torch and NumPy random generators right
                before sample ``i`` is re-initialized.
        """
        for i in range(self._batch_size):
            if gen_seeds is not None:
                seed = gen_seeds[i]
                torch.manual_seed(seed)
                np.random.seed(seed)
            # initialize the latent image
            latent_image = self.latent_images[i]
            nn.init.uniform_(latent_image, 0.0, 1.0)

            # initialize the unet weights
            unet = self.unets[i]
            unet.apply(_reset_weight)

    def forward(self) -> torch.Tensor:
        """Run each UNet on its own latent image and return the batch."""
        images = [
            unet(latent_image)
            for unet, latent_image in zip(self.unets, self.latent_images)
        ]
        images = torch.cat(images, dim=0)  # (batch_size, 3, height, width)
        return self.domain.send(images)

    def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]:
        """Return an iterator of the unet parameters."""
        return self.unets.parameters(recurse)

    def grad_norm(self, norm_type: float = 2) -> list[float]:
        """
        Compute the sample-wise norm of the gradients of the UNet parameters.

        Parameters whose ``.grad`` is ``None`` are skipped; a UNet without
        any gradient contributes ``0.0``.

        Args:
            norm_type: Order of the norm, forwarded to
                ``torch.nn.utils.get_total_norm``.

        Returns:
            list[float]: One gradient norm per sample, in batch order.
        """
        norms = []
        for unet in self.unets:
            # The norm must be taken over the gradients, not the weights.
            grads = [p.grad for p in unet.parameters() if p.grad is not None]
            if not grads:
                norms.append(0.0)
                continue
            total = torch.nn.utils.get_total_norm(grads, norm_type)
            norms.append(float(total))
        return norms


class FrozenGenerator(Generator):
    """Generator that optimizes the latent input of a given network.

    ``parameters()`` yields only the per-sample latents, so an optimizer
    built from it leaves ``generator_network`` untouched. The network is
    switched to eval mode at construction time.

    Args:
        generator_network (nn.Module): Network mapping latents to images.
        latent_shape (tuple[int, ...]): ``(batch_size, *latent_dims)``. One
            latent of shape ``(1, *latent_dims)`` is created per sample.
        latent_upperbound (torch.Tensor | float | None): Upper bound applied
            to the latents before every forward pass; ``None`` disables it.
        latent_lowerbound (torch.Tensor | float | None): Lower bound applied
            to the latents before every forward pass; ``None`` disables it.
        domain (image_domain.ImageDomain | None): Domain whose ``send`` is
            applied to the network output. With ``None`` the network output
            is returned unchanged.
        device (torch.device | None): Device the latents are created on.
        dtype (torch.dtype | None): Data type of the latents.
    """

    def __init__(
        self,
        generator_network: nn.Module,
        latent_shape: tuple[int, ...], # (batch_size, *latent_shape)
        latent_upperbound: torch.Tensor | float | None = None, # (*latent_shape)
        latent_lowerbound: torch.Tensor | float | None = None, # (*latent_shape)
        domain: image_domain.ImageDomain | None = None,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self._generator_network = generator_network
        self._domain = domain
        self._generator_network.eval()
        # Plain list: the latents are exposed through parameters() only.
        self._latent_list = [
            nn.Parameter(
                torch.empty(
                    [1, *latent_shape[1:]], requires_grad=True, **factory_kwargs
                )
            )
            for _ in range(latent_shape[0])
        ]
        self.latent_upperbound = latent_upperbound
        self.latent_lowerbound = latent_lowerbound
        self.reset_states()

    def reset_states(self, gen_seeds: Optional[list[int]] = None) -> None:
        """Re-draw every latent from a standard normal distribution.

        Args:
            gen_seeds: Optional per-sample seeds. When given, ``gen_seeds[i]``
                is passed to ``torch.manual_seed`` (global generator) right
                before latent ``i`` is drawn.

        Raises:
            ValueError: If ``gen_seeds`` is given but its length differs from
                the number of latents.
        """
        if gen_seeds is not None and len(gen_seeds) != len(self._latent_list):
            raise ValueError(
                f"expected {len(self._latent_list)} seeds, got {len(gen_seeds)}"
            )
        for i, latent in enumerate(self._latent_list):
            if gen_seeds is not None:
                torch.manual_seed(gen_seeds[i])
            nn.init.normal_(latent, 0.0, 1.0)

    def forward(self) -> torch.Tensor:
        """Clamp the latents to their bounds and generate the images."""
        lower = self.latent_lowerbound
        upper = self.latent_upperbound
        if lower is not None or upper is not None:
            # Project the latents back into the allowed box. Each bound is
            # applied on its own so that a one-sided constraint is honoured,
            # and the projection is kept out of the autograd graph.
            with torch.no_grad():
                for latent in self._latent_list:
                    if lower is not None:
                        latent.clamp_(min=lower)
                    if upper is not None:
                        latent.clamp_(max=upper)
        latent = torch.cat(self._latent_list, dim=0)
        image = self._generator_network(latent)
        if self._domain is None:
            # No domain configured: hand back the raw network output.
            return image
        return self._domain.send(image)

    def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]:
        """Return an iterator over the per-sample latent parameters.

        ``recurse`` is accepted for signature compatibility with
        ``nn.Module.parameters`` and is ignored.
        """
        return iter(self._latent_list)
