from __future__ import annotations

from collections.abc import Iterable

import numpy as np
import torch
import torch.nn as nn
from bdpy.dl.torch import FeatureExtractor
from bdpy.dl.torch.torch import (
    FeatureExtractorHandleDetach,
    FeatureExtractorHandle,
)
from bdpy.dl.torch import models
from . import image_domain


class Encoder(nn.Module):
    """Module that turns images into per-layer features of a network.

    Incoming images are first passed through ``domain.receive`` and the result
    is handed to a ``FeatureExtractor`` built around ``feature_network``.

    Parameters
    ----------
    feature_network : nn.Module
        Network to read features from. It should provide a `forward` method
        that takes an image tensor and propagates it through the network.
    layer_names : list[str]
        Names of the layers whose features are returned.
    domain : image_domain.ImageDomain
        Image domain whose ``receive`` method is applied to the input images.
    device : torch.device
        Device handed to the feature extractor.
    layer_mapping: dict[str, str], optional
        Mapping from (human-readable) layer names to layer names in the model.
        If None, the given layer names are used as-is.
    """

    def __init__(
        self,
        feature_network: nn.Module,
        layer_names: list[str],
        domain: image_domain.ImageDomain,
        device: torch.device,
        layer_mapping: dict[str, str] | None = None,
    ) -> None:
        super().__init__()
        # NOTE(review): detach=False presumably keeps the extracted features
        # attached to the autograd graph -- confirm against bdpy.
        extractor = FeatureExtractor(
            encoder=feature_network,
            layers=layer_names,
            layer_mapping=layer_mapping,
            detach=False,
            device=device,
        )
        self.feature_extractor = extractor
        self.domain = domain
        # Expose the network instance held by the extractor.
        self.feature_network = extractor._encoder

    def forward(self, images: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the images through the domain and the feature extractor.

        Parameters
        ----------
        images : torch.Tensor
            Input images.

        Returns
        -------
        dict[str, torch.Tensor]
            Whatever the feature extractor returns for the received images
            (features indexed by layer name).
        """
        received = self.domain.receive(images)
        return self.feature_extractor(received)
    

class TransformerEncoder(Encoder):
    """Encoder for transformer models.

    Works like `Encoder`, but the extracted features are additionally passed
    through `tuple_feat_to_tensor_feat`, which replaces tuple-valued features
    by their first element.
    """

    def __init__(self, *args, **kwargs):
        # Construction is identical to Encoder; arguments are forwarded as-is.
        super().__init__(*args, **kwargs)

    def forward(self, images: torch.Tensor) -> dict[str, torch.Tensor]:
        """Extract features and convert tuple-valued entries to tensors.

        Parameters
        ----------
        images : torch.Tensor
            Input images.

        Returns
        -------
        dict[str, torch.Tensor]
            Features indexed by layer name.
        """
        raw_features = self.feature_extractor(self.domain.receive(images))
        return tuple_feat_to_tensor_feat(raw_features)


class VAEEncoder(Encoder):
    """Encoder for VAE models.

    The single extracted feature is treated as a concatenation of mean and
    logvar along dimension 1; only the mean half is returned.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Construction is identical to Encoder; arguments are forwarded as-is.
        super().__init__(*args, **kwargs)

    def forward(self, images: torch.Tensor) -> dict[str, torch.Tensor]:
        """Extract the feature and keep only its mean part.

        Parameters
        ----------
        images : torch.Tensor
            Input images.

        Returns
        -------
        dict[str, torch.Tensor]
            The feature dict produced by the extractor, with its single entry
            replaced by the first of the two chunks along dimension 1.

        Raises
        ------
        AssertionError
            If the extractor does not return exactly one layer feature.
        """
        images = self.domain.receive(images)
        features = self.feature_extractor(images)
        # Raised explicitly (instead of via `assert`) so the check is not
        # stripped under `python -O`; the exception type is kept for callers.
        if len(features) != 1:
            raise AssertionError("VAEEncoder expects only one layer feature")
        for layer, feat in features.items():
            # Split dim 1 into two chunks; the second one (logvar) is discarded.
            # Only the value of an existing key is replaced, so updating the
            # dict while iterating over it is safe.
            mean, _ = torch.chunk(feat, 2, dim=1)
            features[layer] = mean
        return features

    
class CLIPEncoder(TransformerEncoder, nn.Module):
    """Encoder for CLIP models.

    Uses `CLIPFeatureExtractor` instead of the regular feature extractor so
    that the vision encoder part of the model is the one being run. The
    forward pass is inherited from `TransformerEncoder`, i.e. tuple-valued
    features are converted to tensors.

    Parameters
    ----------
    feature_network : nn.Module
        CLIP model to extract features from.
    layer_names : list[str]
        Names of the layers whose features are returned.
    domain : image_domain.ImageDomain
        Image domain whose ``receive`` method is applied to the input images.
    device : torch.device
        Device handed to the feature extractor.
    layer_mapping: dict[str, str], optional
        Mapping from (human-readable) layer names to layer names in the model.
        If None, the given layer names are used as-is.
    """

    def __init__(
        self,
        feature_network: nn.Module,
        layer_names: list[str],
        domain: image_domain.ImageDomain,
        device: torch.device,
        layer_mapping: dict[str, str] | None = None,
    ) -> None:
        # Only nn.Module is initialised here. The parent constructors are
        # skipped on purpose: they would build another feature extractor.
        nn.Module.__init__(self)
        clip_extractor = CLIPFeatureExtractor(
            encoder=feature_network,
            layers=layer_names,
            layer_mapping=layer_mapping,
            detach=False,
            device=device,
        )
        self.feature_extractor = clip_extractor
        self.domain = domain
        # Expose the network instance held by the extractor.
        self.feature_network = clip_extractor._encoder

def tuple_feat_to_tensor_feat(
    features: dict[str, torch.Tensor | tuple],
) -> dict[str, torch.Tensor]:
    """Replace tuple-valued features by their first element, in place.

    Parameters
    ----------
    features : dict
        Features indexed by layer name. Each value must be a ``torch.Tensor``
        (left untouched) or a tuple whose first element is kept and whose
        remaining elements are all None.

    Returns
    -------
    dict
        The same dict object that was passed in, after modification.

    Raises
    ------
    AssertionError
        If a tuple holds a non-None element after the first one.
    ValueError
        If a value is neither a tuple nor a ``torch.Tensor``.
    """
    for layer, feat in features.items():
        if isinstance(feat, torch.Tensor):
            continue
        if not isinstance(feat, tuple):
            raise ValueError(f"Unexpected type of feature: {type(feat)}")
        # In ViT models None can be returned in the tuple. Anything else
        # would be silently dropped, so reject it. Raised explicitly (instead
        # of via `assert`) so the check also runs under `python -O`.
        if any(extra is not None for extra in feat[1:]):
            raise AssertionError(
                'Found not-None feature in the non-first tuple element'
            )
        # Only the value of an existing key changes, which is safe while
        # iterating over the dict.
        features[layer] = feat[0]
    return features


class CLIPFeatureExtractor(object):
    """Feature extractor for CLIP models.

    Forward hooks are registered on the requested layers, and the model is
    run through its ``get_image_features`` method so that the hooks collect
    the activations for an image input.
    """

    def __init__(
        self,
        encoder: nn.Module,
        layers: Iterable[str],
        layer_mapping: dict[str, str] | None = None,
        device: str | torch.device = 'cpu',
        detach: bool = True,
    ):
        """Set up the extractor and register the forward hooks.

        Parameters
        ----------
        encoder : torch.nn.Module
            Network model we want to extract features from. It must provide
            a ``get_image_features`` method.
        layers : Iterable[str]
            Layer names we want to extract features from.
        layer_mapping : dict[str, str], optional
            Mapping from (human-readable) layer names to layer names in the model.
            If None, layers will be directly used as layer names in the model.
        device : str or torch.device, optional
            Device the encoder is moved to (default: 'cpu').
        detach : bool, optional
            If True, detach the feature activations from the computation graph
            and return them as numpy arrays.
        """
        self._encoder = encoder
        # Materialise once: the layer names are iterated again on every run(),
        # so a one-shot iterable would be exhausted after registration.
        self.__layers = list(layers)
        self.__layer_map = layer_mapping
        self.__detach = detach
        self.__device = device
        # Created before anything below can raise, so __del__ also works on a
        # partially constructed instance.
        self._hook_handles = []

        if detach:
            self._extractor = FeatureExtractorHandleDetach()
        else:
            self._extractor = FeatureExtractorHandle()

        self._encoder.to(self.__device)

        for layer in self.__layers:
            if self.__layer_map is not None:
                layer = self.__layer_map[layer]
            layer_object = models._parse_layer_name(self._encoder, layer)
            # Keep the handle so exactly this hook can be removed later.
            self._hook_handles.append(
                layer_object.register_forward_hook(self._extractor)
            )

    def __call__(
        self, x: torch.Tensor | np.ndarray
    ) -> dict[str, np.ndarray] | dict[str, torch.Tensor]:
        """Shorthand for `run`."""
        return self.run(x)

    def run(
        self, x: torch.Tensor | np.ndarray
    ) -> dict[str, np.ndarray] | dict[str, torch.Tensor]:
        """Extract feature activations from the specified layers.

        Parameters
        ----------
        x : numpy.ndarray or torch.Tensor
            Input image. A tensor is used as-is; any other input gets a
            leading axis prepended and is converted to a tensor on the
            configured device.

        Returns
        -------
        features : dict[str, numpy.ndarray] or dict[str, torch.Tensor]
            Feature activations from the specified layers.
            Each key is the layer name and each value is the feature
            activation. Values are numpy arrays when `detach` is True;
            otherwise the collected hook outputs are returned unchanged.
        """
        self._extractor.clear()
        if isinstance(x, torch.Tensor):
            xt = x
        else:
            xt = torch.tensor(x[np.newaxis], device=self.__device)

        # The return value is ignored; the features are collected by the hooks.
        self._encoder.get_image_features(xt)

        # NOTE: outputs are matched to layer names by position, so this
        # relies on the i-th collected output belonging to the i-th layer.
        features = {
            layer: self._extractor.outputs[i]
            for i, layer in enumerate(self.__layers)
        }
        if not self.__detach:
            return features

        return {
            name: value.cpu().detach().numpy()
            for name, value in features.items()
        }

    def __del__(self):
        """Remove the forward hooks registered by this extractor.

        Only the hooks registered in ``__init__`` are removed; other forward
        hooks on the model are kept. The stored handles are used rather than
        deleting from the layer's hook dict while iterating over it, which
        raises ``RuntimeError`` and leaves the remaining hooks in place.
        """
        handles = getattr(self, "_hook_handles", None)
        if not handles:
            return
        for handle in handles:
            handle.remove()
        handles.clear()

