from typing import Optional, Union

import torch as t
from torch import nn

from auto_encoder import device
from auto_encoder.config import AutoEncoderConfig
from auto_encoder.helpers.ae_output_types import VanillaAEOutput
from auto_encoder.models.base_ae import AutoEncoderBase


class IdentityEncoder(nn.Module):
    """Parameter-free encoder module that hands its input back untouched."""

    # No explicit constructor: the module holds no state of its own, so the
    # inherited ``nn.Module.__init__`` is all that is required.

    def forward(self, x: t.Tensor) -> t.Tensor:
        """Return ``x`` itself (the same tensor object, not a copy)."""
        return x


class IdentityAE(AutoEncoderBase):
    """
    An identity dictionary, i.e. the identity function.

    The encoder and the activation function are pass-throughs, and the decoder
    weight is overwritten with an identity matrix of size
    ``config.num_features``. Pre-/post-processing and ``decode`` are not defined
    in this class and are presumably inherited from ``AutoEncoderBase``.
    """

    def __init__(
        self,
        config: AutoEncoderConfig,
        medoid_initial_tensor_N: Optional[t.Tensor] = None,
        preprocess_scaling_factor: Optional[float] = None,
        device: str = device,
    ):
        """
        Build the identity autoencoder.

        Args:
            config: Model configuration; ``config.num_features`` sets the size
                of the identity matrix written into the decoder weight.
            medoid_initial_tensor_N: Forwarded unchanged to the base class.
            preprocess_scaling_factor: Forwarded unchanged to the base class.
            device: Forwarded unchanged to the base class.
        """
        super().__init__(config, medoid_initial_tensor_N, preprocess_scaling_factor, device)

        ## ENCODER
        self.encoder = IdentityEncoder()

        ## ACTIVATION FUNCTION
        # Nothing to construct: `activation_fn` below is a pass-through.

        ## DECODER
        # `self.decoder` is not created in this class; it is assumed to be set
        # up by `AutoEncoderBase.__init__` -- TODO confirm.
        # Build the identity with the dtype/device of the weight it replaces, so
        # that overwriting `.data` does not silently move the parameter to the
        # default device or change its dtype.
        existing_weight = self.decoder.weight
        self.decoder.weight.data = t.eye(
            config.num_features,
            dtype=existing_weight.dtype,
            device=existing_weight.device,
        )

    def encode(self, x: t.Tensor) -> t.Tensor:
        """Run ``x`` through the (identity) encoder and return the result."""
        return self.encoder(x)

    def activation_fn(self, pre_activation_features: t.Tensor) -> t.Tensor:
        """Return the pre-activation features unchanged."""
        return pre_activation_features

    def forward(
        self,
        x: t.Tensor,
        output_intermediate_activations: bool = False,
    ) -> Union[VanillaAEOutput, t.Tensor]:
        """
        Preprocess, encode, activate, decode and postprocess ``x``.

        Args:
            x: Input tensor handed to ``self.preprocess``.
            output_intermediate_activations: When true, return a
                ``VanillaAEOutput`` bundling the reconstruction, the features
                and the normalising constants; otherwise return only the
                reconstruction.

        Returns:
            The postprocessed reconstruction, or a ``VanillaAEOutput``
            wrapping it together with the intermediate values.
        """
        x_center, normalising_constants = self.preprocess(x)

        pre_activation_features = self.encode(x_center)
        features = self.activation_fn(pre_activation_features)
        raw_x_hat = self.decode(features)

        # Postprocessing receives the constants produced by `preprocess`.
        x_hat = self.postprocess(raw_x_hat, normalising_constants)

        if output_intermediate_activations:
            return VanillaAEOutput(
                x_hat_BSN=x_hat,
                features_BSF=features,
                normalising_constants_BS=normalising_constants,
            )
        return x_hat


if __name__ == "__main__":
    # Manual smoke run: build a 15-neuron / 15-feature model, feed it a random
    # (2, 7, 15) tensor, and print the input followed by the model output.
    demo_config = AutoEncoderConfig(num_neurons=15, num_features=15)
    identity_model = IdentityAE(demo_config)

    sample_input = t.rand(2, 7, 15)
    print(sample_input)

    model_output = identity_model(sample_input)
    print(model_output)
